diff --git a/bin/nnc/laplacian_test.cpp b/bin/nnc/laplacian_test.cpp
new file mode 100644
index 000000000..abd55d682
--- /dev/null
+++ b/bin/nnc/laplacian_test.cpp
@@ -0,0 +1,523 @@
+extern "C" {
+#include 
+#include 
+#include 
+#include 
+}
+#include "nnc/mfa/v2/ShaderCache.hpp"
+#include "nnc/mfa/v2/GEMMDescriptor.hpp"
+#include "nnc/mfa/v2/GEMMKernelDescriptor.hpp"
+#include "nnc/mfa/v2/GEMMKernel.hpp"
+#include "3rdparty/dsfmt/dSFMT.h"
+#include 
+
+ShaderCache shaderCache;
+
+std::pair<int, int> profileProblemSize(GEMMDescriptor descriptor)
+{
+	const int problemSize = descriptor.matrixDimensions[0];
+
+	// Allocate FP32 memory for the operands.
+	float* A = (float*)ccmalloc(sizeof(float) * problemSize * problemSize);
+	float* B = (float*)ccmalloc(sizeof(float) * problemSize * problemSize);
+	float* C = (float*)ccmalloc(sizeof(float) * problemSize * problemSize);
+	float* bias = (float*)ccmalloc(sizeof(float) * problemSize);
+
+	// Initialize A as the 2nd-order periodic Laplacian.
+	int i, j;
+	for (i = 0; i < problemSize; i++)
+		for (j = 0; j < problemSize; j++)
+			A[i * problemSize + j] = 0;
+	for (i = 0; i < problemSize; i++)
+	{
+		const int diagonalAddress = i * problemSize + i;
+		A[diagonalAddress] = -2;
+
+		const int leftColumnID = (i + problemSize - 1) % problemSize;
+		const int leftSubDiagonalAddress = i * problemSize + leftColumnID;
+		A[leftSubDiagonalAddress] = 1;
+
+		const int rightColumnID = (i + problemSize + 1) % problemSize;
+		const int rightSubDiagonalAddress = i * problemSize + rightColumnID;
+		A[rightSubDiagonalAddress] = 1;
+	}
+
+	dsfmt_t dsfmt;
+	dsfmt_init_gen_rand(&dsfmt, 1);
+	// Initialize B to random numbers.
+	for (int rowID = 0; rowID < problemSize; rowID++)
+	{
+		for (int columnID = 0; columnID < problemSize; columnID++)
+		{
+			const int address = rowID * problemSize + columnID;
+			B[address] = dsfmt_genrand_open_close(&dsfmt);
+		}
+	}
+
+	// Initialize the bias to random numbers.
+ for (int rowID = 0; rowID < problemSize; rowID++) + { + bias[rowID] = dsfmt_genrand_open_close(&dsfmt); + } + void* A_storage = nullptr; + if (descriptor.memoryPrecisions.A == GEMMOperandPrecision::FP16) + { + A_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize * problemSize); + ccv_float_to_half_precision(A, (uint16_t*)A_storage, problemSize * problemSize); + void* t = A_storage; + A_storage = A; + A = (float*)t; + } else if (descriptor.memoryPrecisions.A == GEMMOperandPrecision::BF16) { + A_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize * problemSize); + for (int i = 0; i < problemSize * problemSize; i++) + ((uint16_t*)A_storage)[i] = ((uint16_t*)A)[i * 2 + 1]; + void* t = A_storage; + A_storage = A; + A = (float*)t; + } + void* B_storage = nullptr; + if (descriptor.memoryPrecisions.B == GEMMOperandPrecision::FP16) + { + B_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize * problemSize); + ccv_float_to_half_precision(B, (uint16_t*)B_storage, problemSize * problemSize); + void* t = B_storage; + B_storage = B; + B = (float*)t; + } else if (descriptor.memoryPrecisions.B == GEMMOperandPrecision::BF16) { + B_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize * problemSize); + for (int i = 0; i < problemSize * problemSize; i++) + ((uint16_t*)B_storage)[i] = ((uint16_t*)B)[i * 2 + 1]; + void* t = B_storage; + B_storage = B; + B = (float*)t; + } + void* bias_storage = nullptr; + if (descriptor.memoryPrecisions.bias == GEMMOperandPrecision::FP16) + { + bias_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize); + ccv_float_to_half_precision(bias, (uint16_t*)bias_storage, problemSize); + void* t = bias_storage; + bias_storage = bias; + bias = (float*)t; + } else if (descriptor.memoryPrecisions.bias == GEMMOperandPrecision::BF16) { + bias_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize); + for (int i = 0; i < problemSize; i++) + ((uint16_t*)bias_storage)[i] = ((uint16_t*)bias)[i * 2 + 1]; + void* t = bias_storage; + bias_storage = bias; + bias = (float*)t; + } + + // Since the Laplacian is symmetric, we swap roles of the matrices to test + // transposition of the left-hand side. + // + // Note that the test cannot cover correctness of A and B transposition + // simultaneously. Instead, test the correctness in isolation + // (AB, AB^T, A^T B). Performance can be tested in all four permutations + // (AB, AB^T, A^T B, A^T B^T). + if (descriptor.transposeState[0]) + { + float* t = A; + A = B; + B = t; + void* t_storage = A_storage; + A_storage = B_storage; + B_storage = t_storage; + } + + // Multiply A with B. + int maxGFLOPS = 0; + int occupancy = 0; + DeviceProperties dprops; + dprops.coreCount = 18; + NS::SharedPtr device = NS::TransferPtr(MTL::CreateSystemDefaultDevice()); + NS::SharedPtr queue = NS::TransferPtr(device->newCommandQueue()); + { + // Generate the kernel. 
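// Editorial sketch of the BF16 conversion used above (hypothetical helpers, not
// part of the patch): BF16 is the upper half of an FP32 bit pattern, so truncating
// keeps the sign, the full 8-bit exponent, and the top 7 mantissa bits. The loops
// above read index [i * 2 + 1] for the same reason, assuming a little-endian host.
#include <cstdint>
#include <cstring>

static inline uint16_t fp32_to_bf16_truncate(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return (uint16_t)(bits >> 16);   // keep the high half
}

static inline float bf16_to_fp32(uint16_t h) {
  const uint32_t bits = (uint32_t)h << 16;  // low mantissa bits become zero
  float x;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}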
+ auto pipelineValue = shaderCache.findKernel(descriptor, device.get(), dprops); + occupancy = pipelineValue->pipeline->maxTotalThreadsPerThreadgroup(); + NS::SharedPtr bufferA = NS::TransferPtr(device->newBuffer(A, descriptor.memoryPrecisions.A.size() * problemSize * problemSize, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked)); + NS::SharedPtr bufferB = NS::TransferPtr(device->newBuffer(B, descriptor.memoryPrecisions.B.size() * problemSize * problemSize, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked)); + NS::SharedPtr bufferC = NS::TransferPtr(device->newBuffer(C, descriptor.memoryPrecisions.C.size() * problemSize * problemSize, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked)); + NS::SharedPtr bufferBias = NS::TransferPtr(device->newBuffer(bias, descriptor.memoryPrecisions.bias.size() * problemSize, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked)); + + // load = directAccessCondition, + // store = false + // problemSize = 1488 | A B | 832 threads/core | 8175 GFLOPS + // problemSize = 1488 | A B^T | 1024 threads/core | 8712 GFLOPS + // problemSize = 1488 | A^T B | 1024 threads/core | 8818 GFLOPS + // problemSize = 1488 | A^T B^T | 1024 threads/core | 8972 GFLOPS + // problemSize = 1489 | A B | 768 threads/core | 7888 GFLOPS + // problemSize = 1489 | A B^T | 768 threads/core | 8256 GFLOPS + // problemSize = 1489 | A^T B | 768 threads/core | 8026 GFLOPS + // problemSize = 1489 | A^T B^T | 832 threads/core | 8463 GFLOPS + // + // load = directAccessCondition + // store = directAccessCondition && (gid.y * M_group < M_edge) && (gid.x * N_group < N_edge) + // problemSize = 1488 | A B | 832 threads/core | 8186 GFLOPS + // problemSize = 1488 | A B^T | 1024 threads/core | 8709 GFLOPS + // problemSize = 1488 | A^T B | 1024 threads/core | 8808 GFLOPS + // problemSize = 1488 | A^T B^T | 1024 threads/core | 8984 GFLOPS + // problemSize = 1489 | A B | 768 threads/core | 7902 GFLOPS + // problemSize = 1489 | A B^T | 768 threads/core | 8249 GFLOPS + // problemSize = 1489 | A^T B | 768 threads/core | 8034 GFLOPS + // problemSize = 1489 | A^T B^T | 832 threads/core | 8469 GFLOPS + // + // load = directAccessCondition && (gid.y * M_group < M_edge) && (gid.x * N_group < N_edge) + // store = directAccessCondition && (gid.y * M_group < M_edge) && (gid.x * N_group < N_edge) + // problemSize = 1488 | A B | 832 threads/core | 8181 GFLOPS + // problemSize = 1488 | A B^T | 1024 threads/core | 8710 GFLOPS + // problemSize = 1488 | A^T B | 1024 threads/core | 8806 GFLOPS + // problemSize = 1488 | A^T B^T | 1024 threads/core | 8979 GFLOPS + // problemSize = 1489 | A B | 768 threads/core | 7892 GFLOPS + // problemSize = 1489 | A B^T | 768 threads/core | 8242 GFLOPS + // problemSize = 1489 | A^T B | 768 threads/core | 8034 GFLOPS + // problemSize = 1489 | A^T B^T | 832 threads/core | 8461 GFLOPS + // + // load previous C = false (M1 Max) + // problemSize = 1488 | A B | 896 threads/core | 8358 GFLOPS + // problemSize = 1488 | A B^T | 1024 threads/core | 8682 GFLOPS + // problemSize = 1488 | A^T B | 1024 threads/core | 8803 GFLOPS + // problemSize = 1488 | A^T B^T | 1024 threads/core | 9024 GFLOPS + // problemSize = 1489 | A B | 768 threads/core | 8039 GFLOPS + // problemSize = 1489 | A B^T | 832 threads/core | 8376 GFLOPS + // problemSize = 1489 | A^T B | 832 threads/core | 8374 GFLOPS + // problemSize = 1489 | A^T B^T | 832 threads/core | 8654 GFLOPS + // + // load previous C = true (M1 Max) + // problemSize = 1488 | A B 
| 896 threads/core | 8352 GFLOPS + // problemSize = 1488 | A B^T | 896 threads/core | 8515 GFLOPS + // problemSize = 1488 | A^T B | 1024 threads/core | 8760 GFLOPS + // problemSize = 1488 | A^T B^T | 1024 threads/core | 9007 GFLOPS + // problemSize = 1489 | A B | 768 threads/core | 7917 GFLOPS + // problemSize = 1489 | A B^T | 768 threads/core | 7992 GFLOPS + // problemSize = 1489 | A^T B | 832 threads/core | 8185 GFLOPS + // problemSize = 1489 | A^T B^T | 832 threads/core | 8583 GFLOPS + // + // load previous C = false (M4) + // problemSize = 1488 | A B | 1024 threads/core | 3353 GFLOPS + // problemSize = 1488 | A B^T | 1024 threads/core | 3324 GFLOPS + // problemSize = 1488 | A^T B | 1024 threads/core | 3338 GFLOPS + // problemSize = 1488 | A^T B^T | 1024 threads/core | 3289 GFLOPS + // problemSize = 1489 | A B | 1024 threads/core | 3375 GFLOPS + // problemSize = 1489 | A B^T | 1024 threads/core | 3317 GFLOPS + // problemSize = 1489 | A^T B | 1024 threads/core | 3343 GFLOPS + // problemSize = 1489 | A^T B^T | 1024 threads/core | 3298 GFLOPS + // + // load previous C = true (M4) + // problemSize = 1488 | A B | 1024 threads/core | 3374 GFLOPS + // problemSize = 1488 | A B^T | 1024 threads/core | 3312 GFLOPS + // problemSize = 1488 | A^T B | 1024 threads/core | 3321 GFLOPS + // problemSize = 1488 | A^T B^T | 1024 threads/core | 3249 GFLOPS + // problemSize = 1489 | A B | 1024 threads/core | 3323 GFLOPS + // problemSize = 1489 | A B^T | 1024 threads/core | 3280 GFLOPS + // problemSize = 1489 | A^T B | 1024 threads/core | 3308 GFLOPS + // problemSize = 1489 | A^T B^T | 1024 threads/core | 3256 GFLOPS + + // Profile the latency of matrix multiplication. + for (int i = 0; i < 15; i++) + { + const int duplicatedCommandCount = 20; + NS::SharedPtr commandBuffer = NS::TransferPtr(queue->commandBuffer()); + NS::SharedPtr encoder = NS::TransferPtr(commandBuffer->computeCommandEncoder()); + encoder->setComputePipelineState(pipelineValue->pipeline.get()); + encoder->setThreadgroupMemoryLength(pipelineValue->kernel->threadgroupMemoryAllocation, 0); + encoder->setBuffer(bufferA.get(), 0, 0); + encoder->setBuffer(bufferB.get(), 0, 1); + encoder->setBuffer(bufferC.get(), 0, 2); + encoder->useResource(bufferA.get(), MTL::ResourceUsageRead); + encoder->useResource(bufferB.get(), MTL::ResourceUsageRead); + encoder->useResource(bufferC.get(), MTL::ResourceUsageWrite); + if (descriptor.useBias) + { + encoder->setBuffer(bufferBias.get(), 0, 3); + encoder->useResource(bufferBias.get(), MTL::ResourceUsageRead); + } + for (int j = 0; j < duplicatedCommandCount; j++) + { + auto ceilDivide = + [=](int64_t target, uint16_t granularity) -> int64_t { + return (target + int64_t(granularity) - 1) / int64_t(granularity); + }; + MTL::Size gridSize = MTL::Size(ceilDivide(problemSize, pipelineValue->kernel->blockDimensions[1]), ceilDivide(problemSize, pipelineValue->kernel->blockDimensions[0]), 1); + MTL::Size groupSize = MTL::Size(pipelineValue->kernel->threadgroupSize, 1, 1); + encoder->dispatchThreadgroups(gridSize, groupSize); + } + encoder->endEncoding(); + commandBuffer->commit(); + commandBuffer->waitUntilCompleted(); + auto start = commandBuffer->GPUStartTime(); + auto end = commandBuffer->GPUEndTime(); + auto latency = end - start; + + // Determine the amount of work done. + auto operations = (int64_t)2 * problemSize * problemSize * problemSize; + operations = operations * duplicatedCommandCount; + auto gflops = (int)((double)operations / (double)latency / 1e9); + + // Report the results. 
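// Editorial sketch of the arithmetic behind the GFLOPS figures above (illustrative
// numbers only, not measurements from this patch): each command buffer encodes
// duplicatedCommandCount GEMMs, and an n x n x n GEMM costs 2 * n^3 flops
// (one multiply plus one add per inner-product term).
#include <cstdint>

static inline int computeGflops(int n, int duplicatedCommandCount, double latencySeconds) {
  const int64_t operations = (int64_t)2 * n * n * n * duplicatedCommandCount;
  return (int)((double)operations / latencySeconds / 1e9);
}
// e.g. n = 1488 with 20 duplicated commands is ~1.32e11 flops; finishing in
// ~15.1 ms works out to roughly 8.7e3 GFLOPS, the ballpark of the A B^T rows above.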
+ // let latencyMicroseconds = Int(latency / 1e-6) + // print(latencyMicroseconds, "μs", gflops, "GFLOPS") + maxGFLOPS = std::max(maxGFLOPS, gflops); + } + // Copy the results to C. + { + auto precision = descriptor.memoryPrecisions.C; + auto raw = bufferC->contents(); + for (int rowID = 0; rowID < problemSize; rowID++) + { + for (int columnID = 0; columnID < problemSize; columnID++) + { + const int address = rowID * problemSize + columnID; + float entry32; + switch (precision.value) { + case GEMMOperandPrecision::FP32: + entry32 = ((float*)raw)[address]; + break; + case GEMMOperandPrecision::FP16: { + uint16_t value = ((uint16_t*)raw)[address]; + ccv_half_precision_to_float(&value, &entry32, 1); + break; + } + case GEMMOperandPrecision::BF16: { + uint16_t value[2]; + value[0] = 0; + value[1] = ((uint16_t*)raw)[address]; + entry32 = *(float*)value; + } + } + C[address] = entry32; + } + } + } + } + + // Choose an error threshold. + auto createErrorThreshold = + [=](GEMMOperandPrecision precision) -> float { + switch (precision.value) { + case GEMMOperandPrecision::FP32: + return 1e-5; + case GEMMOperandPrecision::FP16: + return 5e-3; + case GEMMOperandPrecision::BF16: + return 5e-2; + } + }; + float errorThreshold = 0; + { + auto memoryPrecisions = descriptor.memoryPrecisions; + auto thresholdA = createErrorThreshold(memoryPrecisions.A); + auto thresholdB = createErrorThreshold(memoryPrecisions.B); + auto thresholdC = createErrorThreshold(memoryPrecisions.C); + errorThreshold = std::max(errorThreshold, thresholdA); + errorThreshold = std::max(errorThreshold, thresholdB); + errorThreshold = std::max(errorThreshold, thresholdC); + } + // Check the results. + int errorCount = 0; + if (A_storage != nullptr) + { + void* t = A_storage; + A_storage = A; + A = (float*)t; + } + if (B_storage != nullptr) + { + void* t = B_storage; + B_storage = B; + B = (float*)t; + } + if (bias_storage != nullptr) + { + void* t = bias_storage; + bias_storage = bias; + bias = (float*)t; + } + for (int m = 0; m < problemSize; m++) + { + for (int n = 0; n < problemSize; n++) + { + // Find the source row IDs. + int leftRowID = (m + problemSize - 1) % problemSize; + int centerRowID = m; + int rightRowID = (m + problemSize + 1) % problemSize; + + // Find the source scalars. + float leftSource; + float centerSource; + float rightSource; + float biasSource; + if (descriptor.transposeState[0]) + { + leftSource = A[leftRowID * problemSize + n]; + centerSource = A[centerRowID * problemSize + n]; + rightSource = A[rightRowID * problemSize + n]; + biasSource = descriptor.useBias ? bias[n] : 0; + } else if (descriptor.transposeState[1]) { + leftSource = B[n * problemSize + leftRowID]; + centerSource = B[n * problemSize + centerRowID]; + rightSource = B[n * problemSize + rightRowID]; + biasSource = descriptor.useBias ? bias[n] : 0; + } else { + leftSource = B[leftRowID * problemSize + n]; + centerSource = B[centerRowID * problemSize + n]; + rightSource = B[rightRowID * problemSize + n]; + biasSource = descriptor.useBias ? bias[n] : 0; + } + + // Find the expected result. + float expected = leftSource - 2 * centerSource + rightSource + biasSource; + + // Find the actual result. + float actual; + if (descriptor.transposeState[0]) + { + actual = C[n * problemSize + m]; + } else { + actual = C[m * problemSize + n]; + } + + // Report whether it is correct. 
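// Editorial sketch of the identity the check above relies on (standalone code with
// hypothetical names, not part of the patch): multiplying the periodic
// second-difference matrix built at the top of this test by any vector x gives the
// three-point stencil x[i-1] - 2*x[i] + x[i+1] with wraparound indices, which is
// exactly the expected value computed above (plus the optional bias term).
#include <vector>
#include <cassert>

static void verifyPeriodicLaplacianStencil(int n) {
  // Needs n >= 3 so the left, center, and right columns are distinct.
  if (n < 3)
    return;
  std::vector<float> L(n * n, 0.0f), x(n), y(n, 0.0f);
  for (int i = 0; i < n; i++) {
    L[i * n + i] = -2;
    L[i * n + (i + n - 1) % n] = 1;
    L[i * n + (i + 1) % n] = 1;
    x[i] = (float)(i % 7);   // small integers keep the arithmetic exact
  }
  for (int i = 0; i < n; i++)        // naive mat-vec
    for (int j = 0; j < n; j++)
      y[i] += L[i * n + j] * x[j];
  for (int i = 0; i < n; i++) {
    const float stencil = x[(i + n - 1) % n] - 2 * x[i] + x[(i + 1) % n];
    assert(y[i] == stencil);
  }
}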
+ float error = fabs(expected - actual); + if (error > errorThreshold) + { + if (errorCount < 10) + { + printf("error: %f / ~1.000\n", error); + errorCount += 1; + } + } + } + } + ccfree(A); + ccfree(B); + ccfree(C); + ccfree(bias); + if (A_storage != nullptr) + ccfree(A_storage); + if (B_storage != nullptr) + ccfree(B_storage); + if (bias_storage != nullptr) + ccfree(bias_storage); + return std::make_pair(maxGFLOPS, occupancy); +} + +struct TestDescriptor { + GEMMOperandPrecision precision; + int problemSize; + bool transposeState[2]; + bool useBias; +}; + +void runTest(TestDescriptor descriptor) +{ + // Set up the kernel. + GEMMDescriptor gemmDesc = GEMMDescriptor(); + auto precision = descriptor.precision; + unsigned int n = (unsigned int)descriptor.problemSize; + gemmDesc.matrixDimensions = simd::uint3 { n, n, n }; + gemmDesc.memoryPrecisions = { + .A = precision, .B = precision, .C = precision, .bias = precision + }; + gemmDesc.transposeState = simd::uchar3 { descriptor.transposeState[0], descriptor.transposeState[1], descriptor.transposeState[0] }; + gemmDesc.useBias = descriptor.useBias; + + // Test the kernel. + auto statistic = profileProblemSize(gemmDesc); + + // Report the results. + std::cout << "problemSize = " << descriptor.problemSize << " | "; + if (descriptor.transposeState[0]) + { + std::cout << "A^T "; + } else { + std::cout << "A "; + } + if (descriptor.transposeState[1]) + { + std::cout << "B^T | "; + } else { + std::cout << "B | "; + } + + std::cout << statistic.first << " GFLOPS " << statistic.second << " threads/core | " << std::endl; +} + +int main(int argc, char** argv) +{ + ccv_nnc_init(); + { + int problemSizes[] = { + 7, 8, 9, 10, + 15, 16, 17, 18, + 23, 24, 25, + 31, 32, 33, + 47, 48, 49, + 63, 64, 65, + 103, 104, 112, + 126, 127, 128, 129, + 130, 131, + 135, 136, 137, + 143, 144, 145, + 151, 152, 153, + }; + bool transposeStates[] = { + false, false, + false, true, + true, false, + }; + printf("Correctness tests:\n"); + for (int i = 0; i < sizeof(problemSizes) / sizeof(int); i++) + { + for (int j = 0; j < sizeof(transposeStates) / (sizeof(bool) * 2); j++) + { + TestDescriptor testDescriptor = TestDescriptor(); + testDescriptor.precision = GEMMOperandPrecision::FP32; + testDescriptor.problemSize = problemSizes[i]; + testDescriptor.transposeState[0] = transposeStates[j * 2]; + testDescriptor.transposeState[1] = transposeStates[j * 2 + 1]; + testDescriptor.useBias = false; + runTest(testDescriptor); + } + } + } + { + bool transposeStates[] = { + false, false, + false, true, + true, false, + true, true, + false, false, + false, true, + true, false, + true, true, + }; + bool useBias[] = { + false, + false, + false, + false, + true, + true, + true, + true + }; + + printf("\nPerformance tests:\n"); + for (int problemSize = 3072; problemSize <= 3072; problemSize++) + { + for (int j = 0; j < sizeof(transposeStates) / (sizeof(bool) * 2); j++) + { + TestDescriptor testDescriptor = TestDescriptor(); + testDescriptor.precision = GEMMOperandPrecision::BF16; + testDescriptor.problemSize = problemSize; + testDescriptor.transposeState[0] = transposeStates[j * 2]; + testDescriptor.transposeState[1] = transposeStates[j * 2 + 1]; + testDescriptor.useBias = useBias[j]; + runTest(testDescriptor); + } + } + } + return 0; +} diff --git a/bin/nnc/makefile b/bin/nnc/makefile index 09feaa39a..422743656 100644 --- a/bin/nnc/makefile +++ b/bin/nnc/makefile @@ -4,7 +4,7 @@ LDFLAGS := -L"../../lib" -lccv $(LDFLAGS) CFLAGS := -O3 -Wall -I"../../lib" $(CFLAGS) NVFLAGS := -O3 -I"../../lib" 
-lineinfo $(NVFLAGS) -TARGETS = nnc-e2e-verify nnc-e2e-sym-verify nnc-sym cifar-10 imagenet coco imdb iwslt wmt csv imdb_lstm +TARGETS = nnc-e2e-verify nnc-e2e-sym-verify nnc-sym cifar-10 imagenet coco imdb iwslt wmt csv imdb_lstm laplacian_test FUZZ_TARGETS = csv_fuzz @@ -37,6 +37,9 @@ libccv.a: %.o: %.c $(CC) $< -o $@ -c $(CFLAGS) +laplacian_test.o: laplacian_test.cpp + $(CC) $< -o $@ -c $(CFLAGS) -std=c++17 + .gitignore: echo $(TARGETS) | tr ' ' '\n' > .gitignore diff --git a/lib/configure b/lib/configure index 224cc6d4b..8329af1b5 100755 --- a/lib/configure +++ b/lib/configure @@ -4760,7 +4760,7 @@ if test "$mps_support" = yes; then printf "%s\n" "yes" >&6; } DEFINE_MACROS="$DEFINE_MACROS-D HAVE_MPS " - MKLDFLAGS="$MKLDFLAGS-framework MetalPerformanceShaders -framework MetalPerformanceShadersGraph -framework Foundation -framework Metal -lc++ " + MKLDFLAGS="$MKLDFLAGS-framework MetalPerformanceShaders -framework MetalPerformanceShadersGraph -framework Foundation -framework Metal -framework OpenCL -lc++ " CUDA_SRCS="" diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m index 48d02d2f7..c65bda63c 100644 --- a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m +++ b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m @@ -175,10 +175,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context(); const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0)); - // v1 only supports the same precision of accumulator as the tensor. - int is_different_accumulator_precision = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F) || ((cmd.info.blas.flags & CCV_NNC_GEMM_16F) && a_datatype == CCV_32F); + int is_upcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F); const int is_mfa_supported = - ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || (!(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM) && !is_different_accumulator_precision)); + ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM)); size_t a_data_size = 0; if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX) @@ -364,11 +363,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .A_trans = (is_transpose_a ? 1 : 0), .B_trans = (is_transpose_w ? 1 : 0), .D_trans = 0, - .alpha = (float)1.0, - .beta = (float)0.0, .batched = is_batched, - .fused_activation_function = 0, .fused_bias = (bias ? 1 : 0), + .register_float = (is_upcast ? 1 : 0), .batch_dims_a = { 0 }, .batch_dims_b = { 0 }, @@ -795,10 +792,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .A_trans = 1, .B_trans = (is_transpose_w ? 1 : 0), .D_trans = 0, - .alpha = (float)1.0, - .beta = (float)0.0, .batched = is_batched, - .fused_activation_function = 0, .fused_bias = 0, .batch_dims_a = { 0 }, @@ -834,10 +828,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .A_trans = 0, .B_trans = (is_transpose_w ? 
0 : 1), .D_trans = 0, - .alpha = (float)1.0, - .beta = (float)0.0, .batched = is_batched, - .fused_activation_function = 0, .fused_bias = 0, .batch_dims_a = { 0 }, @@ -881,10 +872,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .A_trans = 1, .B_trans = (is_transpose_a ? 1 : 0), .D_trans = 0, - .alpha = (float)1.0, - .beta = (float)0.0, .batched = is_batched, - .fused_activation_function = 0, .fused_bias = 0, .batch_dims_a = { 0 }, @@ -920,10 +908,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .A_trans = (is_transpose_a ? 0 : 1), .B_trans = 0, .D_trans = 0, - .alpha = (float)1.0, - .beta = (float)0.0, .batched = is_batched, - .fused_activation_function = 0, .fused_bias = 0, .batch_dims_a = { 0 }, diff --git a/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m b/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m index 114667feb..94bacfd3b 100644 --- a/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m +++ b/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m @@ -256,10 +256,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .A_trans = 0, .B_trans = 1, .D_trans = 0, - .alpha = (float)1.0, - .beta = (float)0.0, .batched = is_batched, - .fused_activation_function = 0, .fused_bias = (bias ? 1 : 0), .batch_dims_a = { 0 }, @@ -275,10 +272,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .A_trans = 0, .B_trans = 0, .D_trans = 1, - .alpha = (float)1.0, - .beta = (float)0.0, .batched = is_batched, - .fused_activation_function = 0, .fused_bias = (bias ? 1 : 0), .batch_dims_a = { 0 }, diff --git a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m index d49ef0b62..f7436938c 100644 --- a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m +++ b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m @@ -316,10 +316,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c .A_trans = false, .B_trans = true, .D_trans = false, - .alpha = (float)1.0, - .beta = (float)0.0, .batched = 0, - .fused_activation_function = 0, .fused_bias = (bias ? 
1 : 0), .batch_dims_a = { 0 }, diff --git a/lib/nnc/mfa/ccv_nnc_mfa.cpp b/lib/nnc/mfa/ccv_nnc_mfa.cpp index d357156e1..5912ef654 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa.cpp @@ -10,6 +10,10 @@ mfa::context* ccv_nnc_init_mfa_context(MTL::Device* device) { return new mfa::context(device); } +void ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_mfa_context_t* context) { + context->v2_cache.evict(); +} + void ccv_nnc_deinit_mfa_context(mfa::context* context) { delete context; } @@ -86,12 +90,6 @@ void mfa::cache::prepare(mfa::co _mfa_cache_prepare(&map, context, hash); } -template <> -void mfa::cache::prepare(mfa::context* context, mfa::gemm::hash hash) -{ - _mfa_cache_prepare(&map, context, hash); -} - template <> void mfa::cache::prepare(mfa::context* context, mfa::normalization::hash hash) { diff --git a/lib/nnc/mfa/ccv_nnc_mfa.hpp b/lib/nnc/mfa/ccv_nnc_mfa.hpp index f39de331f..1ba6f733a 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa.hpp +++ b/lib/nnc/mfa/ccv_nnc_mfa.hpp @@ -4,11 +4,11 @@ #include "nnc/ccv_nnc.h" #include "ccv_nnc_mfa_defines.hpp" #include "ccv_nnc_mfa_attention.hpp" -#include "ccv_nnc_mfa_gemm.hpp" #include "ccv_nnc_mfa_normalization.hpp" #include "ccv_nnc_mfa_depalettize.hpp" #include "ccv_nnc_mfa_adam.hpp" #include "ccv_nnc_mfa_cmul.hpp" +#include "ccv_nnc_mfa_gemm.hpp" #include "ccv_nnc_mfa_gemv.hpp" #include "ccv_nnc_mfa_cast.hpp" #include "ccv_nnc_mfa_add.hpp" @@ -17,6 +17,7 @@ #include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp" #include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp" #include "ccv_nnc_mfa_error.hpp" +#include "v2/ShaderCache.hpp" #include namespace ccv { @@ -48,7 +49,6 @@ class context { context(MTL::Device* device); cache attention_cache; - cache gemm_cache; cache normalization_cache; cache depalettize_cache; cache adam_cache; @@ -56,6 +56,8 @@ class context { cache gemv_cache; cache cast_cache; cache add_cache; + + ShaderCache v2_cache; MTL::Buffer* request_scratch(uint64_t size); }; @@ -68,6 +70,7 @@ extern "C" { #endif // __cplusplus ccv_nnc_mfa_context_t* ccv_nnc_init_mfa_context(mtl_device_t* context); +void ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_mfa_context_t* context); void ccv_nnc_deinit_mfa_context(ccv_nnc_mfa_context_t* context); uint8_t ccv_nnc_mfa_context_supported(ccv_nnc_mfa_context_t* context); uint16_t ccv_nnc_mfa_context_log_level(ccv_nnc_mfa_context_t* context); diff --git a/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp b/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp index fc2cccd96..8142aff5b 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp @@ -3,48 +3,65 @@ #include using namespace ccv::nnc; +#include "v2/ShaderCache.hpp" +#include "v2/GEMMKernel.hpp" +#include "v2/GEMMKernelDescriptor.hpp" +#include "v2/GEMMDescriptor.hpp" #include // MARK: - C void ccv_nnc_mfa_prepare_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t params) { - context->gemm_cache.prepare(context, mfa::gemm::hash(params)); + // No-op. 
} void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t params, MTL::CommandBatch* command_batch, MTL::Buffer** tensors, size_t* tensor_offsets) { - mfa::gemm::hash hash(params); - auto iterator = context->gemm_cache.map.find(hash); - if (iterator == context->gemm_cache.map.end()) { - mfa::precondition_failure("GEMM hash not cached.", __LINE__, __FILE__, __FUNCTION__); - } - - auto* pipeline = iterator->second; - auto encoder = command_batch->startCommand(); - encoder->setComputePipelineState(pipeline->pso.get()); - encoder->setThreadgroupMemoryLength(pipeline->threadgroup_memory_length, 0); - int num_tensors = 0; while (tensors[num_tensors] != nullptr) { num_tensors += 1; } CCV_NNC_MFA_PRECONDITION((num_tensors == 3) || (num_tensors == 4)) - - encoder->useResource(tensors[0], MTL::ResourceUsageRead); - encoder->useResource(tensors[1], MTL::ResourceUsageRead); - encoder->useResource(tensors[2], MTL::ResourceUsageWrite); - if (num_tensors >= 4) { - encoder->useResource(tensors[3], MTL::ResourceUsageRead); - } - for (int i = 0; i < num_tensors; ++i) { - encoder->setBuffer(tensors[i], tensor_offsets[i], i); + + // Branch on whether to use the new kernel. + GEMMDescriptor gemmDesc; + gemmDesc.matrixDimensions = simd::uint3 { + params.M, + params.N, + params.K, + }; + switch (params.data_type) { + case MTL::DataTypeHalf: { + gemmDesc.memoryPrecisions = { + .A = GEMMOperandPrecision::FP16, + .B = GEMMOperandPrecision::FP16, + .C = GEMMOperandPrecision::FP16, + .bias = GEMMOperandPrecision::FP16, + }; + break; + } + case MTL::DataTypeFloat: { + gemmDesc.memoryPrecisions = { + .A = GEMMOperandPrecision::FP32, + .B = GEMMOperandPrecision::FP32, + .C = GEMMOperandPrecision::FP32, + .bias = GEMMOperandPrecision::FP32, + }; + break; + } + default: + CCV_NNC_MFA_PRECONDITION(false); + break; } - - // Simple broadcasting rules; not yet support for NumPy broadcasting rules. - simd::ushort4 num_batch_dims(0); - simd::ulong4 batch_sizes(1); + gemmDesc.transposeState = simd::uchar3 { params.A_trans, params.B_trans, params.D_trans }; + gemmDesc.registerPrecisionC = (params.register_float) ? std::optional(GEMMOperandPrecision::FP32) : std::nullopt; + gemmDesc.leadingDimensions = std::nullopt; + gemmDesc.loadPreviousC = false; + gemmDesc.useBias = params.fused_bias; if (params.batched) { + simd::ushort4 num_batch_dims(0); + simd::ulong4 batch_sizes(1); for (uint16_t operand = 0; operand < 4; ++operand) { uint32_t* batch_dims; if (operand == 0) { @@ -56,12 +73,12 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa continue; } else if (operand == 3) { // Skip the D operand if unavailable. - if (!(params.fused_activation_function || params.fused_bias)) { + if (!params.fused_bias) { continue; } batch_dims = params.batch_dims_d; } - + for (int i = 0; i < CCV_NNC_MAX_DIM_ALLOC; ++i) { if (batch_dims[i] == 0) { break; @@ -71,250 +88,78 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa } } - uint16_t data_type_size = 0; - switch (params.data_type) { - case MTL::DataTypeHalf: { - data_type_size = 2; - break; - } - case MTL::DataTypeFloat: { - data_type_size = 4; - break; - } - default: - CCV_NNC_MFA_PRECONDITION(false); - break; - } - uint64_t byte_stride_a = hash.M * hash.K * data_type_size; - uint64_t byte_stride_b = hash.K * hash.N * data_type_size; - uint64_t byte_stride_c = hash.M * hash.N * data_type_size; - uint64_t byte_stride_d = (hash.D_trans ? 
hash.M : hash.N) * data_type_size; + uint32_t stride_a = params.M * params.K; + uint32_t stride_b = params.K * params.N; + uint32_t stride_c = params.M * params.N; + uint32_t stride_d = params.D_trans ? params.M : params.N; if (batch_sizes[0] == 1) { - byte_stride_a = 0; + stride_a = 0; } if (batch_sizes[1] == 1) { - byte_stride_b = 0; + stride_b = 0; } if (batch_sizes[3] == 1) { - byte_stride_d = 0; + stride_d = 0; } const unsigned long batch_size = std::max(batch_sizes[0], batch_sizes[1]); - simd::ulong4 matrix_offsets[batch_size]; - for (int i = 0; i < batch_size; ++i) { - matrix_offsets[i] = simd::ulong4 { - i * byte_stride_a, - i * byte_stride_b, - i * byte_stride_c, - i * byte_stride_d, - }; - } - if (batch_size * 32 > 4096) { - auto buffer = context->device->newBuffer(matrix_offsets, batch_size * 32, MTL::ResourceStorageModeShared); - encoder->useResource(buffer, MTL::ResourceUsageRead); - encoder->setBuffer(buffer, 0, 10); - buffer->release(); - } else { - encoder->setBytes(matrix_offsets, batch_size * 32, 10); - } + gemmDesc.batchDimension = batch_size; + simd::uint4 batchStrides; + batchStrides[0] = stride_a; + batchStrides[1] = stride_b; + batchStrides[2] = stride_c; + batchStrides[3] = stride_d; + gemmDesc.batchStrides = batchStrides; + } else { + gemmDesc.batchDimension = 1; + gemmDesc.batchStrides = std::nullopt; } - - auto grid_size = pipeline->grid_size; - grid_size.depth = batch_sizes[0]; - encoder->dispatchThreadgroups(grid_size, pipeline->group_size); - command_batch->finishCommand(encoder); -} - -// MARK: - C++ - -mfa::gemm::hash::hash(ccv_nnc_mfa_gemm_params_t params) { - data_type = params.data_type; - M = params.M; - N = params.N; - K = params.K; - A_trans = params.A_trans; - B_trans = params.B_trans; - D_trans = params.D_trans; - alpha = params.alpha; - beta = params.beta; - batched = params.batched; - fused_activation_function = params.fused_activation_function; - fused_bias = params.fused_bias; -} - -bool mfa::gemm::hash::operator==(const mfa::gemm::hash& hash) const { - return - (data_type == hash.data_type) && - (M == hash.M) && - (N == hash.N) && - (K == hash.K) && - (A_trans == hash.A_trans) && - (B_trans == hash.B_trans) && - (D_trans == hash.D_trans) && - (alpha == hash.alpha) && - (beta == hash.beta) && - (batched == hash.batched) && - (fused_activation_function == hash.fused_activation_function) && - (fused_bias == hash.fused_bias); -} -std::ostream& operator<<(std::ostream& os, const mfa::gemm::hash& hash) { - os << "mfa::gemm::hash {"; - os << " .data_type = " << hash.data_type << ','; - os << " .M = " << hash.M << ','; - os << " .N = " << hash.N << ','; - os << " .K = " << hash.K << ','; - os << " .A_trans = " << bool(hash.A_trans) << ','; - os << " .B_trans = " << bool(hash.B_trans) << ','; - os << " .D_trans = " << bool(hash.D_trans) << ','; - os << " .alpha = " << double(hash.alpha) << ','; - os << " .beta = " << double(hash.beta) << ','; - os << " .batched = " << bool(hash.batched) << ','; - os << " .fused_activation_function = " << bool(hash.fused_activation_function) << ','; - os << " .fused_bias = " << bool(hash.fused_bias) << " "; - os << "}"; - return os; -} + // Instantiate the kernel. + // + // TODO: Remove the autoreleasepool, once you confirm the caller always + // makes one. Or find a different solution, like spawning a pool inside + // of 'fetchKernel' when a new kernel variant is compiled. 
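// Editorial sketch of the broadcasting rule set up above (hypothetical helper, not
// part of the patch): any operand whose batch dimension collapsed to 1 gets a zero
// batch stride, so every batch index reuses the same matrix, while C always
// advances by a full M*N matrix per batch.
#include <cstdint>
#include <simd/simd.h>

static inline simd::uint4 broadcastBatchStrides(
    uint32_t M, uint32_t N, uint32_t K, bool D_trans,
    bool broadcastA, bool broadcastB, bool broadcastD) {
  simd::uint4 strides;
  strides[0] = broadcastA ? 0 : M * K;              // A
  strides[1] = broadcastB ? 0 : K * N;              // B
  strides[2] = M * N;                               // C
  strides[3] = broadcastD ? 0 : (D_trans ? M : N);  // bias / D
  return strides;
}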
+ auto pool = NS::AutoreleasePool::alloc()->init(); + auto &shaderCache = context->v2_cache; + DeviceProperties dprops = DeviceProperties(); + auto pipelineValue = shaderCache.findKernel(gemmDesc, context->device.get(), dprops); + pool->drain(); + auto kernel = pipelineValue->kernel; + auto pipeline = pipelineValue->pipeline; -std::size_t std::hash::operator()(const mfa::gemm::hash& hash) const noexcept { - std::size_t seed = 0; - using namespace mfa::hash; - combine_64(seed, hash.data_type); - combine_64(seed, pack_64(simd::uint2 { hash.M, hash.N })); - combine_64(seed, pack_64(simd::uint2 { hash.K, pack_32(simd::uchar4 { hash.A_trans, hash.B_trans, hash.D_trans, 0 }) })); - combine_64(seed, pack_64(simd::uint2 { *reinterpret_cast(&hash.alpha), *reinterpret_cast(&hash.beta) })); - combine_32(seed, pack_32(simd::uchar4 { hash.batched, hash.fused_activation_function, hash.fused_bias, 0 })); - return seed; -} + // Allocate a new command. + auto encoder = command_batch->startCommand(); + encoder->setComputePipelineState(pipeline.get()); + encoder->setThreadgroupMemoryLength(kernel->threadgroupMemoryAllocation, 0); -mfa::gemm::pipeline::pipeline(mfa::context* context, mfa::gemm::hash hash) { - CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf)) - CCV_NNC_MFA_PRECONDITION(hash.alpha == 1.0) - CCV_NNC_MFA_PRECONDITION(hash.beta == 0.0) - CCV_NNC_MFA_PRECONDITION(hash.fused_activation_function == false) - - auto* pool = NS::AutoreleasePool::alloc()->init(); - - auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init()); - constants->setConstantValue(&hash.M, MTL::DataTypeUInt, NS::UInteger(0)); - constants->setConstantValue(&hash.N, MTL::DataTypeUInt, 1); - constants->setConstantValue(&hash.K, MTL::DataTypeUInt, 2); - constants->setConstantValue(&hash.A_trans, MTL::DataTypeBool, 10); - constants->setConstantValue(&hash.B_trans, MTL::DataTypeBool, 11); - constants->setConstantValue(&hash.D_trans, MTL::DataTypeBool, 13); - constants->setConstantValue(&hash.alpha, MTL::DataTypeFloat, 20); - constants->setConstantValue(&hash.beta, MTL::DataTypeFloat, 21); - constants->setConstantValue(&hash.batched, MTL::DataTypeBool, 100); - constants->setConstantValue(&hash.fused_activation_function, MTL::DataTypeBool, 101); - constants->setConstantValue(&hash.fused_bias, MTL::DataTypeBool, 50001); - simd::ulong4 garbage(0); - constants->setConstantValue(&garbage, MTL::DataTypeBool, 102); - constants->setConstantValue(&garbage, MTL::DataTypeBool, 103); - constants->setConstantValue(&garbage, MTL::DataTypeBool, 113); - constants->setConstantValue(&garbage, MTL::DataTypeBool, 50000); - - // Eventually, this may incorporate the batch size. - // BxMxN > 1,000,000 -> 48x48, only if M >= 88 and N >= 88 - // BxMxN > 4,000,000 -> 64x64, only if M >= 120 and N >= 120 - uint64_t C_elements = uint64_t(hash.M) * uint64_t(hash.N); - if (hash.batched) { - C_elements *= 2; - } - int is_half = (hash.data_type == MTL::DataTypeHalf); // SD v1 attention - int is_float = (hash.data_type == MTL::DataTypeFloat); // SD v2 attention - - uint16_t M_group = 32; - uint16_t N_group = 32; - uint16_t K_simd = 32; - if (C_elements > 1000 * 1000) { - M_group = 48; - N_group = 48; - } - - // If K_simd is perfectly equal to matrix K, the compiler can elide a large - // amount of logic in the kernel. 
- if (hash.K >= 33 && hash.K <= 40) { - K_simd = 40; // 1 * 40 - } else if (is_half && hash.K >= 73 && hash.K <= 80) { - K_simd = 40; // 2 * 40 - } else if (C_elements > 1000 * 1000) { - if (hash.K <= 24) { - K_simd = 24; // 1 * 24 - } else if (hash.K <= 32) { - K_simd = 32; // 1 * 32 - } else if (hash.K <= 48) { - K_simd = 24; - } else if (hash.K <= 64) { - K_simd = 32; - } else if (is_float) { - K_simd = 24; - } - } - - uint16_t M_splits = 2; - uint16_t N_splits = 2; - uint16_t M_simd = M_group / M_splits; - uint16_t N_simd = N_group / N_splits; - - constants->setConstantValue(&M_simd, MTL::DataTypeUShort, 200); - constants->setConstantValue(&N_simd, MTL::DataTypeUShort, 201); - constants->setConstantValue(&K_simd, MTL::DataTypeUShort, 202); - constants->setConstantValue(&M_splits, MTL::DataTypeUShort, 210); - constants->setConstantValue(&N_splits, MTL::DataTypeUShort, 211); - - std::string cpp_name; - uint16_t data_type_size = UINT16_MAX; - switch (hash.data_type) { - case MTL::DataTypeHalf: { - cpp_name = "hgemm"; - data_type_size = 2; - break; - } - case MTL::DataTypeFloat: { - cpp_name = "sgemm"; - data_type_size = 4; - break; - } - default: { - CCV_NNC_MFA_PRECONDITION(false) - break; - } - } - auto* swift_name = NS::String::string(cpp_name.c_str(), NS::UTF8StringEncoding); - - uint16_t A_block_bytes = M_group * K_simd * data_type_size; - uint16_t B_block_bytes = K_simd * N_group * data_type_size; - uint16_t C_block_bytes = M_group * N_group * data_type_size; - threadgroup_memory_length = A_block_bytes + B_block_bytes; - - if ((hash.M % 8 > 0) && (hash.N % 8 > 0)) { - if (C_block_bytes > threadgroup_memory_length) { - threadgroup_memory_length = C_block_bytes; - } + // Bind the function arguments. + encoder->useResource(tensors[0], MTL::ResourceUsageRead); + encoder->useResource(tensors[1], MTL::ResourceUsageRead); + encoder->useResource(tensors[2], MTL::ResourceUsageWrite); + if (num_tensors >= 4) { + encoder->useResource(tensors[3], MTL::ResourceUsageRead); } - if (hash.fused_bias) { - uint16_t D_block_bytes = (hash.D_trans ? M_group : N_group) * data_type_size; - if (D_block_bytes > threadgroup_memory_length) { - threadgroup_memory_length = D_block_bytes; - } + for (int i = 0; i < num_tensors; ++i) { + encoder->setBuffer(tensors[i], tensor_offsets[i], i); } - - std::function ceil_divide = [](size_t original, uint16_t granularity) { - return (original + size_t(granularity) - 1) / size_t(granularity); + + // Calculate the grid size. + auto ceilDivide = + [=](int64_t target, uint16_t granularity) -> int64_t { + return (target + int64_t(granularity) - 1) / int64_t(granularity); }; - grid_size = MTL::Size(ceil_divide(hash.N, N_group), ceil_divide(hash.M, M_group), 1); - group_size = MTL::Size(32 * M_splits * N_splits, 1, 1); - - NS::Error* error = nullptr; - auto function = NS::TransferPtr(context->library->newFunction(swift_name, constants.get(), &error)); - if (!function) { - CCV_NNC_MFA_CHECK_ERROR(error) - } - - pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error)); - if (!pso) { - CCV_NNC_MFA_CHECK_ERROR(error) - } - - pool->drain(); + MTL::Size gridSize + (ceilDivide(int64_t(params.N), kernel->blockDimensions[1]), + ceilDivide(int64_t(params.M), kernel->blockDimensions[0]), + gemmDesc.batchDimension); + MTL::Size groupSize + (int64_t(kernel->threadgroupSize), 1, 1); + + // Dispatch the required number of threads. + encoder->dispatchThreadgroups(gridSize, groupSize); + + // Finish the command. 
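// Editorial sketch of the dispatch geometry above (illustrative numbers, not from
// the patch): the grid tiles the output matrix with one threadgroup per
// blockDimensions[0] x blockDimensions[1] (M x N) tile, rounded up at the edges,
// and the batch occupies the third grid dimension.
#include <cstdint>

static inline int64_t ceilDivideExample(int64_t target, uint16_t granularity) {
  return (target + int64_t(granularity) - 1) / int64_t(granularity);
}
// e.g. M = N = 1000 with 48 x 48 blocks -> ceilDivideExample(1000, 48) = 21,
// i.e. a 21 x 21 x batch grid, where the last row and column of tiles cover the
// 1000 % 48 = 40 leftover elements.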
+ command_batch->finishCommand(encoder); } + diff --git a/lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp b/lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp index a2f701394..92a6e1f3c 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp @@ -9,11 +9,9 @@ typedef struct { uint8_t A_trans; uint8_t B_trans; uint8_t D_trans; - float alpha; - float beta; uint8_t batched; - uint8_t fused_activation_function; uint8_t fused_bias; + uint8_t register_float; // Fill these in the same order as the original shape, but null-terminated. // Both arrays must have the same length. @@ -25,56 +23,6 @@ typedef struct { #ifdef __cplusplus #include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp" #include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp" -#include - -namespace ccv { -namespace nnc { -namespace mfa { -namespace gemm { - -class hash { -public: - uint64_t data_type; - uint32_t M; - uint32_t N; - uint32_t K; - uint8_t A_trans; - uint8_t B_trans; - uint8_t D_trans; - float alpha; - float beta; - uint8_t batched; - uint8_t fused_activation_function; - uint8_t fused_bias; - - hash(ccv_nnc_mfa_gemm_params_t); - - bool operator==(const hash& rhs) const; -}; - -class pipeline { -public: - NS::SharedPtr pso; - - uint16_t threadgroup_memory_length; - MTL::Size grid_size; - MTL::Size group_size; - - pipeline(context* context, hash hash); -}; - -} // namespace gemm -} // namespace mfa -} // namespace nnc -} // namespace ccv - -std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::gemm::hash& hash); - -template<> -struct std::hash -{ - std::size_t operator()(const ccv::nnc::mfa::gemm::hash& hash) const noexcept; -}; extern "C" { #endif // __cplusplus diff --git a/lib/nnc/mfa/ccv_nnc_mfa_hash.hpp b/lib/nnc/mfa/ccv_nnc_mfa_hash.hpp index c8b705b2f..cc513d7dc 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_hash.hpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_hash.hpp @@ -50,14 +50,26 @@ inline uint32_t pack_32(const simd::uchar4& v) { return reinterpret_cast(v); } +inline uint32_t pack_32(const simd::ushort2& v) { + return reinterpret_cast(v); +} + inline size_t combine_64(std::size_t& seed, const uint64_t& v) { return rotl(seed, std::numeric_limits::digits/3) ^ distribute_64(v); } +inline uint64_t pack_64(const simd::ushort4& v) { + return reinterpret_cast(v); +} + inline uint64_t pack_64(const simd::uint2& v) { return reinterpret_cast(v); } +inline simd::ulong2 pack_128(const simd::ushort8& v) { + return reinterpret_cast(v); +} + } // namespace hash } // namespace mfa } // namespace nnc diff --git a/lib/nnc/mfa/makefile b/lib/nnc/mfa/makefile index 563e64717..06d3c9ad2 100644 --- a/lib/nnc/mfa/makefile +++ b/lib/nnc/mfa/makefile @@ -2,7 +2,7 @@ include ../../config.mk CFLAGS := -std=c++17 -O3 -Wall -I"../../" $(CFLAGS) -SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp +SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(SRCS))) diff --git a/lib/nnc/mfa/v2/CodeWriter.cpp b/lib/nnc/mfa/v2/CodeWriter.cpp 
new file mode 100644 index 000000000..26df5156a --- /dev/null +++ b/lib/nnc/mfa/v2/CodeWriter.cpp @@ -0,0 +1,51 @@ +#include "CodeWriter.hpp" + +#include + +void CodeWriter::operator+=(std::string text) { + if (!ignore_ident_ && !text.empty()) AppendIdent(stream_); + + while (true) { + auto begin = text.find("{{"); + if (begin == std::string::npos) { break; } + + auto end = text.find("}}"); + if (end == std::string::npos || end < begin) { break; } + + // Write all the text before the first {{ into the stream. + stream_.write(text.c_str(), begin); + + // The key is between the {{ and }}. + const std::string key = text.substr(begin + 2, end - begin - 2); + + // Find the value associated with the key. If it exists, write the + // value into the stream, otherwise write the key itself into the stream. + auto iter = value_map_.find(key); + if (iter != value_map_.end()) { + const std::string &value = iter->second; + stream_ << value; + } else { + assert(false && "could not find key"); + stream_ << key; + } + + // Update the text to everything after the }}. + text = text.substr(end + 2); + } + if (!text.empty() && text.back() == '\\') { + text.pop_back(); + ignore_ident_ = true; + stream_ << text; + } else { + ignore_ident_ = false; + stream_ << text << std::endl; + } +} + +void CodeWriter::AppendIdent(std::stringstream &stream) { + int lvl = cur_ident_lvl_; + while (lvl--) { + stream.write(pad_.c_str(), static_cast(pad_.size())); + } +} + diff --git a/lib/nnc/mfa/v2/CodeWriter.hpp b/lib/nnc/mfa/v2/CodeWriter.hpp new file mode 100644 index 000000000..f2f114331 --- /dev/null +++ b/lib/nnc/mfa/v2/CodeWriter.hpp @@ -0,0 +1,59 @@ +#ifndef MFA_CODE_WRITER_HPP_ +#define MFA_CODE_WRITER_HPP_ + +#include +#include + +class CodeWriter { + public: + CodeWriter(std::string pad = std::string()) + : pad_(pad), cur_ident_lvl_(0), ignore_ident_(false) {} + + // Clears the current "written" code. + void Clear() { + stream_.str(""); + stream_.clear(); + } + + // Associates a key with a value. All subsequent calls to operator+=, where + // the specified key is contained in {{ and }} delimiters will be replaced by + // the given value. + void SetValue(const std::string &key, const std::string &value) { + value_map_[key] = value; + } + + std::string GetValue(const std::string &key) const { + const auto it = value_map_.find(key); + return it == value_map_.end() ? "" : it->second; + } + + // Appends the given text to the generated code as well as a newline + // character. Any text within {{ and }} delimiters is replaced by values + // previously stored in the CodeWriter by calling SetValue above. The newline + // will be suppressed if the text ends with the \\ character. + void operator+=(std::string text); + + // Returns the current contents of the CodeWriter as a std::string. 
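// Editorial usage sketch for the class above (hypothetical caller, not part of this
// header): keys registered with SetValue are substituted wherever they appear
// between {{ and }}, and a trailing backslash suppresses the newline that
// operator+= normally appends.
static inline std::string exampleShaderSource() {
  CodeWriter writer;
  writer.SetValue("NAME", "gemm");
  writer.SetValue("PRECISION", "float");
  writer += "kernel void {{NAME}}(";
  writer += "  device {{PRECISION}} *C [[buffer(2)]]\\";   // newline suppressed
  writer += ") { /* ... */ }";
  return writer.ToString();
  // -> "kernel void gemm(\n  device float *C [[buffer(2)]]) { /* ... */ }\n"
}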
+ std::string ToString() const { return stream_.str(); } + + // Increase ident level for writing code + void IncrementIdentLevel() { cur_ident_lvl_++; } + // Decrease ident level for writing code + void DecrementIdentLevel() { + if (cur_ident_lvl_) cur_ident_lvl_--; + } + + void SetPadding(const std::string &padding) { pad_ = padding; } + + private: + std::map value_map_; + std::stringstream stream_; + std::string pad_; + int cur_ident_lvl_; + bool ignore_ident_; + + // Add ident padding (tab or space) based on ident level + void AppendIdent(std::stringstream &stream); +}; + +#endif diff --git a/lib/nnc/mfa/v2/DeviceProperties.hpp b/lib/nnc/mfa/v2/DeviceProperties.hpp new file mode 100644 index 000000000..6b7489182 --- /dev/null +++ b/lib/nnc/mfa/v2/DeviceProperties.hpp @@ -0,0 +1,8 @@ +#ifndef MFA_DEVICE_PROPERTIES_HPP_ +#define MFA_DEVICE_PROPERTIES_HPP_ + +struct DeviceProperties { + uint32_t coreCount; +}; + +#endif diff --git a/lib/nnc/mfa/v2/GEMMDescriptor.cpp b/lib/nnc/mfa/v2/GEMMDescriptor.cpp new file mode 100644 index 000000000..0b7cc720c --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMDescriptor.cpp @@ -0,0 +1,285 @@ +#include "GEMMDescriptor.hpp" +#include "GEMMKernelDescriptor.hpp" +#include "GEMMKernel.hpp" +#include "../ccv_nnc_mfa_hash.hpp" +#include "../ccv_nnc_mfa_error.hpp" + +bool GEMMDescriptor::operator==(const GEMMDescriptor& rhs) const { + return + (batchDimension == rhs.batchDimension) && + simd_all(matrixDimensions == rhs.matrixDimensions) && + simd_all(leadingDimensions.value_or(simd::uint3(UINT32_MAX)) == rhs.leadingDimensions.value_or(simd::uint3(UINT32_MAX))) && + simd_all(batchStrides.value_or(simd::uint4(UINT32_MAX)) == rhs.batchStrides.value_or(simd::uint4(UINT32_MAX))) && + memoryPrecisions == rhs.memoryPrecisions && + registerPrecisionC == rhs.registerPrecisionC && + simd_all(transposeState == rhs.transposeState) && + (useBias == rhs.useBias); +} + +std::size_t std::hash::operator()(const GEMMDescriptor& hash) const noexcept { + std::size_t seed = 0; + using namespace ccv::nnc::mfa::hash; + combine_64(seed, hash.batchDimension); + combine_32(seed, hash.matrixDimensions[0]); + combine_32(seed, hash.matrixDimensions[1]); + combine_32(seed, hash.matrixDimensions[2]); + if (hash.leadingDimensions.has_value()) { + combine_32(seed, hash.leadingDimensions.value()[0]); + combine_32(seed, hash.leadingDimensions.value()[1]); + combine_32(seed, hash.leadingDimensions.value()[2]); + } + if (hash.batchStrides.has_value()) { + combine_32(seed, hash.batchStrides.value()[0]); + combine_32(seed, hash.batchStrides.value()[1]); + combine_32(seed, hash.batchStrides.value()[2]); + combine_32(seed, hash.batchStrides.value()[3]); + } + combine_64(seed, pack_64(simd::ushort4 { hash.memoryPrecisions.A.value, hash.memoryPrecisions.B.value, hash.memoryPrecisions.C.value, hash.memoryPrecisions.bias.value })); + combine_32(seed, pack_32(simd::uchar4 { hash.transposeState[0], hash.transposeState[1], hash.transposeState[2], 0 })); + combine_32(seed, pack_32(simd::uchar4 { hash.loadPreviousC, hash.useBias, 0, 0 })); + if (hash.registerPrecisionC.has_value()) { + combine_32(seed, pack_32(simd::ushort2 { hash.registerPrecisionC.value().value, 0 })); + } + return seed; +} + +std::pair *> GEMMDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map> *const libraryCache) const noexcept { + // The caller is not responsible for calling 'delete' on this pointer. The + // reference is saved in the 'libraryCache'. 
It will be deallocated whenever + // the shader cache itself is cleaned up. + auto createKernel = + [=](GEMMKernelDescriptor descriptor) -> GEMMKernel* { + auto iterator = libraryCache->find(descriptor); + if (iterator != libraryCache->end()) { + return iterator->second.get(); + } else { + GEMMKernel* kernel = new GEMMKernel(descriptor, device); + (*libraryCache)[descriptor] = std::unique_ptr(kernel); + return kernel; + } + }; + + // WARNING: The owner must explicitly retain the compute pipeline. + auto createPipeline = + [=](MTL::Library* library) -> MTL::ComputePipelineState* { + // Set the function constants. + auto constants = NS::TransferPtr + (MTL::FunctionConstantValues::alloc()->init()); + uint32_t M = this->matrixDimensions[0]; + uint32_t N = this->matrixDimensions[1]; + uint32_t K = this->matrixDimensions[2]; + constants->setConstantValue(&M, MTL::DataTypeUInt, NS::UInteger(0)); + constants->setConstantValue(&N, MTL::DataTypeUInt, 1); + constants->setConstantValue(&K, MTL::DataTypeUInt, 2); + + auto chooseLeadingDimension = + [=](unsigned int specifiedLeading, bool transposeState, unsigned int untransposedRows, unsigned int untransposedColumns) -> unsigned int { + unsigned int expectedLeading; + if (transposeState) { + expectedLeading = untransposedRows; + } else { + expectedLeading = untransposedColumns; + } + + unsigned int actualLeading; + if (specifiedLeading > 0) { + if (specifiedLeading < expectedLeading) { + CCV_NNC_MFA_PRECONDITION(false && "Leading block dimension was too small."); + } + actualLeading = specifiedLeading; + } else { + actualLeading = expectedLeading; + } + + return actualLeading; + }; + + auto leadingDimensionA = chooseLeadingDimension( + leadingDimensions.value_or(simd::uint3())[0], transposeState[0], + matrixDimensions[0], matrixDimensions[2]); + auto leadingDimensionB = chooseLeadingDimension( + leadingDimensions.value_or(simd::uint3())[1], transposeState[1], + matrixDimensions[2], matrixDimensions[1]); + auto leadingDimensionC = chooseLeadingDimension( + leadingDimensions.value_or(simd::uint3())[2], false, + matrixDimensions[0], matrixDimensions[1]); + + constants->setConstantValue(&leadingDimensionA, MTL::DataTypeUInt, 5); + constants->setConstantValue(&leadingDimensionB, MTL::DataTypeUInt, 6); + constants->setConstantValue(&leadingDimensionC, MTL::DataTypeUInt, 7); + + bool loadPreviousC = this->loadPreviousC; + constants->setConstantValue(&loadPreviousC, MTL::DataTypeBool, 10); + + bool batched = this->batchDimension > 1; + constants->setConstantValue(&batched, MTL::DataTypeBool, 11); + simd::uint4 batchStrides = this->batchStrides.value_or(simd::uint4(0)); + auto batchStrideA = batchStrides[0]; + auto batchStrideB = batchStrides[1]; + auto batchStrideC = batchStrides[2]; + auto batchStrideBias = batchStrides[3]; + constants->setConstantValue(&batchStrideA, MTL::DataTypeUInt, 15); + constants->setConstantValue(&batchStrideB, MTL::DataTypeUInt, 16); + constants->setConstantValue(&batchStrideC, MTL::DataTypeUInt, 17); + constants->setConstantValue(&batchStrideBias, MTL::DataTypeUInt, 18); + + NS::String* swiftName = NS::String::string("gemm", NS::UTF8StringEncoding); + NS::Error* error = nil; + + auto function = NS::TransferPtr + (library->newFunction(swiftName, constants.get(), &error)); + CCV_NNC_MFA_CHECK_ERROR(error); + + auto pipeline = device->newComputePipelineState(function.get(), &error); + CCV_NNC_MFA_CHECK_ERROR(error); + return pipeline; + }; + + GEMMOperandPrecision registerPrecisionA = memoryPrecisions.A; + GEMMOperandPrecision 
registerPrecisionB = memoryPrecisions.B; + GEMMOperandPrecision registerPrecisionBias = memoryPrecisions.bias; + GEMMOperandPrecision registerPrecisionC = this->registerPrecisionC.value_or(GEMMOperandPrecision::FP32); + if (!this->registerPrecisionC.has_value() && + memoryPrecisions.A == GEMMOperandPrecision::FP16 && + memoryPrecisions.B == GEMMOperandPrecision::FP16 && + memoryPrecisions.C == GEMMOperandPrecision::FP16) { + // If FP16 is causing accuracy issues, you can change this to FP32. Note + // that doing so cuts out a very important part of the performance + // spectrum. It is only FP16xFP16->FP16 that reaches peak performance. + // This statement applies to both the M1 and M3 architectures. + // + // FP16xFP16 into FP16 accumulator triggers this instruction: + // https://github.com/dougallj/applegpu/blob/aeb81519159246d70c56d3f77adb4bc9cca7aa0d/applegpu.py#L3232-L3244 + // + // FP16xFP16/BF16xBF16 into FP32 accumulator triggers this instruction: + // https://github.com/dougallj/applegpu/blob/aeb81519159246d70c56d3f77adb4bc9cca7aa0d/applegpu.py#L3195-L3207 + // + // No other input/output register types map to a native instruction. + // + // I would recommend changing the accumulator precision on a case-by-case + // (operation-by-operation) basis. Provide some mechanism in the high-level + // API, to control certain low-level features. Without harming execution + // latency and without imposing technical debt on the high-level API. + // Definitely NOT a global flag that forces all matrices to change from + // FP16 -> FP32. + registerPrecisionC = GEMMOperandPrecision::FP16; + } + + // Set the device and examine the block dimensions. + auto blockDimensionsAndPaddedBlockDimensions = GEMMKernelDescriptor::getBlockDimensions(device, dprops.coreCount, this->matrixDimensions, this->batchDimension, this->memoryPrecisions, this->transposeState); + std::optional preferAsyncStore = std::nullopt; + bool preferAsyncLoad; + simd::ushort2 splits; + if (device->supportsFamily(MTL::GPUFamily(1009))) { + preferAsyncLoad = false; + preferAsyncStore = false; + splits = { 1, 1 }; + } else { + // For device without native BF16 support, use register at FP32. + if (memoryPrecisions.A == GEMMOperandPrecision::BF16) { + registerPrecisionA = GEMMOperandPrecision::FP32; + } + if (memoryPrecisions.B == GEMMOperandPrecision::BF16) { + registerPrecisionB = GEMMOperandPrecision::FP32; + } + preferAsyncLoad = true; + if (simd_all(blockDimensionsAndPaddedBlockDimensions.first == simd::ushort3 { 48, 48, 32 })) { + preferAsyncStore.reset(); + } else { + preferAsyncStore = true; + } + splits = { 2, 2 }; + } + const GEMMOperandPrecisions registerPrecisions = { + .A = registerPrecisionA, + .B = registerPrecisionB, + .C = registerPrecisionC, + .bias = registerPrecisionBias, + }; + + // Run a combinatorial search to find the correct value for + // 'preferAsyncStore'. + if (preferAsyncStore.has_value()) { + auto kernelDesc = GEMMKernelDescriptor(blockDimensionsAndPaddedBlockDimensions.first, this->memoryPrecisions, blockDimensionsAndPaddedBlockDimensions.second, preferAsyncLoad, preferAsyncStore.value(), registerPrecisions, splits, this->transposeState, this->useBias); + GEMMKernel* kernel = createKernel(kernelDesc); + auto pipeline = NS::TransferPtr(createPipeline(kernel->library.get())); + + // Force the user to retrieve the return value from the cache. We ensure + // the cache takes ownership, and the pointer doesn't become a zombie + // object. 
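// Editorial sketch condensing the accumulator rule above (hypothetical helper, not
// part of the patch; assumes the types from GEMMOperandPrecision.hpp): accumulate C
// in FP16 only when every memory precision is FP16 and the caller did not request
// an explicit register precision for C; everything else, including BF16 inputs,
// accumulates in FP32.
#include <optional>

static inline GEMMOperandPrecision chooseAccumulatorPrecision(
    const GEMMOperandPrecisions &memory,
    const std::optional<GEMMOperandPrecision> &requestedC) {
  if (requestedC.has_value())
    return requestedC.value();
  const bool allFP16 =
    memory.A == GEMMOperandPrecision::FP16 &&
    memory.B == GEMMOperandPrecision::FP16 &&
    memory.C == GEMMOperandPrecision::FP16;
  return allFP16 ? GEMMOperandPrecision::FP16 : GEMMOperandPrecision::FP32;
}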
+ PipelineValue* output = new PipelineValue { kernel, pipeline }; + return std::make_pair(kernelDesc, output); + } else { + auto kernelDesc = GEMMKernelDescriptor(blockDimensionsAndPaddedBlockDimensions.first, this->memoryPrecisions, blockDimensionsAndPaddedBlockDimensions.second, preferAsyncLoad, false, registerPrecisions, splits, this->transposeState, this->useBias); + struct Candidate { + GEMMKernelDescriptor kernelDesc; + GEMMKernel* kernel; + NS::SharedPtr pipeline; + }; + std::vector candidates; + + for (int8_t candidateID = 0; candidateID < 4; ++candidateID) { + simd::ushort3 blockDimensions; + if (candidateID % 2 == 0) { + blockDimensions = simd::ushort3 { 48, 48, 32 }; + } else { + blockDimensions = simd::ushort3 { 48, 48, 40 }; + } + + bool preferAsyncStore; + if (candidateID / 2 == 0) { + preferAsyncStore = false; + } else { + preferAsyncStore = true; + } + + // Set the data that's unique to this variant. + auto newKernelDesc = kernelDesc; + newKernelDesc.blockDimensions = blockDimensions; + newKernelDesc.preferAsyncStore = preferAsyncStore; + + GEMMKernel* kernel = createKernel(newKernelDesc); + auto pipeline = NS::TransferPtr + (createPipeline(kernel->library.get())); + + Candidate candidate { + .kernelDesc = newKernelDesc, + .kernel = kernel, + .pipeline = pipeline + }; + candidates.push_back(candidate); + } + + // Find the maximum occupancy. + int64_t maximumOccupancy = -1; + for (Candidate candidate : candidates) { + int64_t occupancy = candidate.pipeline->maxTotalThreadsPerThreadgroup(); + maximumOccupancy = std::max(maximumOccupancy, occupancy); + } + + // Remove all candidates that don't match this occupancy. + { + std::vector newCandidates; + for (Candidate candidate : candidates) { + int64_t occupancy = candidate.pipeline->maxTotalThreadsPerThreadgroup(); + if (occupancy != maximumOccupancy) { + continue; + } + newCandidates.push_back(candidate); + } + candidates = newCandidates; + } + + // Choose the highest-performing candidate. + Candidate candidate = candidates[candidates.size() - 1]; + kernelDesc = candidate.kernelDesc; + + // Force the user to retrieve the return value from the cache. We ensure + // the cache takes ownership, and the pointer doesn't become a zombie + // object. + PipelineValue* output = new PipelineValue { + candidate.kernel, candidate.pipeline + }; + return std::make_pair(candidate.kernelDesc, output); + } +} diff --git a/lib/nnc/mfa/v2/GEMMDescriptor.hpp b/lib/nnc/mfa/v2/GEMMDescriptor.hpp new file mode 100644 index 000000000..f0a5134c9 --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMDescriptor.hpp @@ -0,0 +1,57 @@ +#ifndef MFA_GEMMDESCRIPTOR_HPP_ +#define MFA_GEMMDESCRIPTOR_HPP_ + +#include +#include +#include "PipelineValue.hpp" +#include "DeviceProperties.hpp" +#include "GEMMOperandPrecision.hpp" + +struct GEMMKernelDescriptor; +struct GEMMKernel; + +struct GEMMDescriptor { + /// The number of equally sized multiplications that run in parallel. + int64_t batchDimension = 1; + + /// The dimensions of the input and output matrices. + /// - Parameter M: Number of output columns. + /// - Parameter N: Number of output rows. + /// - Parameter K: Number of loop iterations for the dot products. + /// + /// For all practical purposes, one can assume matrix dimensions are 32-bit. + /// I use this quite often in other code. The pointers themselves are 64-bit, + /// but the offsets between different elements are 32-bit. With 4-byte words, + /// this scheme could access up to 16 GB of memory - larger than any array + /// in any reasonable application. 
Handling larger allocations likely + /// requires consideration of more failure points than just integer + /// overflows. + simd::uint3 matrixDimensions; + + GEMMOperandPrecisions memoryPrecisions; + + std::optional registerPrecisionC; + + std::optional leadingDimensions; + + std::optional batchStrides; + + simd::uchar3 transposeState; + + bool loadPreviousC; + + bool useBias; + + bool operator==(const GEMMDescriptor& rhs) const; + + std::pair *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, std::unordered_map> *const libraryCache) const noexcept; +}; + +template<> +struct std::hash +{ + std::size_t operator()(const GEMMDescriptor& hash) const noexcept; +}; + +#endif + diff --git a/lib/nnc/mfa/v2/GEMMHeaders.cpp b/lib/nnc/mfa/v2/GEMMHeaders.cpp new file mode 100644 index 000000000..68a66b3b2 --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMHeaders.cpp @@ -0,0 +1,671 @@ +#include "../ccv_nnc_mfa.hpp" + +#include +#include + +std::string createMetalSimdgroupEvent() { + // Return the source string. + return R"( +// -*- Metal -*- +//===-- metal_simdgroup_event ---------------------------------------------===// +// Copyright (c) 2024 Philip Turner. See MIT LICENSE +//===----------------------------------------------------------------------===// + +#ifndef __METAL_SIMDGROUP_EVENT +#define __METAL_SIMDGROUP_EVENT + +// Invoking the generation of LLVM bitcode for async copies. +// +// %struct._simdgroup_event_t = type opaque +// +struct _simdgroup_event_t; + +// Invoking the generation of LLVM bitcode for async copies. +// +// Bitcode: TBD +// +thread _simdgroup_event_t* +__metal_simdgroup_async_copy_1d( + ulong, ulong, threadgroup void *, const device void *, ulong) + __asm("air.simdgroup_async_copy_1d.p3i8.p1i8"); + +// Invoking the generation of LLVM bitcode for async copies. +// +// Bitcode: TBD +// +thread _simdgroup_event_t* +__metal_simdgroup_async_copy_1d( + ulong, ulong, device void *, const threadgroup void *, ulong) + __asm("air.simdgroup_async_copy_1d.p1i8.p3i8"); + +// Invoking the generation of LLVM bitcode for async copies. +// +// ; Function Attrs: argmemonly convergent nounwind +// declare %struct._simdgroup_event_t* +// @air.simdgroup_async_copy_2d.p3i8.p1i8( +// i64, i64, +// i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>, +// i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>, +// <2 x i64>, i32) +// local_unnamed_addr #4 +// +thread _simdgroup_event_t* +__metal_simdgroup_async_copy_2d( + ulong, ulong, + threadgroup void *, ulong, ulong, ulong2, + const device void *, ulong, ulong, ulong2, + long2, int) + __asm("air.simdgroup_async_copy_2d.p3i8.p1i8"); + +// Invoking the generation of LLVM bitcode for async copies. +// +// ; Function Attrs: argmemonly convergent nounwind +// declare %struct._simdgroup_event_t* +// @air.simdgroup_async_copy_2d.p1i8.p3i8( +// i64, i64, +// i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>, +// i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>, +// <2 x i64>, i32) +// local_unnamed_addr #4 +// +thread _simdgroup_event_t* +__metal_simdgroup_async_copy_2d( + ulong, ulong, + device void *, ulong, ulong, ulong2, + const threadgroup void *, ulong, ulong, ulong2, + long2, int) + __asm("air.simdgroup_async_copy_2d.p1i8.p3i8"); + +// Invoking the generation of LLVM bitcode for async copies. 
+// +// ; Function Attrs: convergent nounwind +// declare void +// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture) +// local_unnamed_addr #3 +// +void __metal_wait_simdgroup_events( + int, thread _simdgroup_event_t**) + __asm("air.wait_simdgroup_events"); + +#pragma METAL internals : enable +namespace metal +{ + enum class simdgroup_async_copy_clamp_mode { + clamp_to_zero = 0, + clamp_to_edge = 1 + }; + + struct simdgroup_event { + METAL_FUNC simdgroup_event() thread {} + + template + METAL_FUNC void async_copy( + threadgroup T *dst, + const device T *src, + ulong n_elements + ) thread { + event = __metal_simdgroup_async_copy_1d( + // Description of the data type. + sizeof(T), + alignof(T), + + // Description of the arguments. + reinterpret_cast(dst), + reinterpret_cast(src), + n_elements); + } + + template + METAL_FUNC void async_copy( + device T *dst, + const threadgroup T *src, + ulong n_elements + ) thread { + event = __metal_simdgroup_async_copy_1d( + // Description of the data type. + sizeof(T), + alignof(T), + + // Description of the arguments. + reinterpret_cast(dst), + reinterpret_cast(src), + n_elements); + } + + template + METAL_FUNC void async_copy( + // Description of the destination. + threadgroup T *dst, + ushort dst_elements_per_row, + ushort2 dst_tile_dimensions, + + // Description of the source. + const device T *src, + uint src_elements_per_row, + ushort2 src_tile_dimensions, + + // Other arguments. + bool transpose_matrix = false, + simdgroup_async_copy_clamp_mode clamp_mode = + simdgroup_async_copy_clamp_mode::clamp_to_zero + ) thread { + if (transpose_matrix) { + src_tile_dimensions = src_tile_dimensions.yx; + dst_tile_dimensions = dst_tile_dimensions.yx; + } + event = __metal_simdgroup_async_copy_2d( + // Description of the data type. + sizeof(T), + alignof(T), + + // Description of the destination. + reinterpret_cast(dst), + ushort(dst_elements_per_row), + 1, + ulong2(dst_tile_dimensions), + + // Description of the source. + reinterpret_cast(src), + uint(src_elements_per_row), + 1, + ulong2(src_tile_dimensions), + + // Other arguments. + long2(0), + static_cast(clamp_mode)); + } + + template + METAL_FUNC void async_copy( + // Description of the destination. + device T *dst, + uint dst_elements_per_row, + ushort2 dst_tile_dimensions, + + // Description of the source. + const threadgroup T *src, + ushort src_elements_per_row, + ushort2 src_tile_dimensions, + + // Other arguments. + bool transpose_matrix = false + ) thread { + if (transpose_matrix) { + src_tile_dimensions = src_tile_dimensions.yx; + dst_tile_dimensions = dst_tile_dimensions.yx; + } + event = __metal_simdgroup_async_copy_2d( + // Description of the data type. + sizeof(T), + alignof(T), + + // Description of the destination. + reinterpret_cast(dst), + uint(dst_elements_per_row), + 1, + ulong2(dst_tile_dimensions), + + // Description of the source. + reinterpret_cast(src), + ushort(src_elements_per_row), + 1, + ulong2(src_tile_dimensions), + + // Other arguments. + long2(0), + 0); + } + + METAL_FUNC static void wait(int count, thread simdgroup_event *events) { + __metal_wait_simdgroup_events( + count, reinterpret_cast(events)); + } + + private: + // Invoking the generation of LLVM bitcode for async copies. 
+ // + // %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* } + // + thread _simdgroup_event_t* event; + }; +} // namespace metal +#pragma METAL internals : disable + +#endif // __METAL_SIMDGROUP_EVENT +)"; +} + +std::string createMetalSimdgroupMatrixStorage(bool BF16) { + // How this header spawning code was designed. + // + // Find the patterns between the load/store functions: + // - device has 'uint' elements_per_row + // - threadgroup has 'ushort' elements_per_row + // - both have 'ushort2' matrix_origin + // + // The origin is 'ushort2' because the 32-bit part of the address should have + // been applied previously during 'apply_offset'. The 16-bit part should be + // hard-coded into the assembly when the GEMM loop is unrolled. + // + // Transpose path: + // - load: reads two values; should split each one onto a separate line. + // - overwrites the value of *thread_elements() with a new vec + // - store: the two instructions are on two separate lines. + // - fetches from lane 0 or 1 of thread_elements()[0] + // - adds 0 or 1 to the hard-coded matrix_origin.x + // + // Address generation: + // - casts some intermediate address fragments to 'ulong' for 'device' + // - keeps all address fragments in 'ushort' for 'threadgroup' + + enum class AddressSpace { + device, + threadgroup, + }; + + auto keyword = + [=](AddressSpace value) -> std::string { + switch (value) { + case AddressSpace::device: + return "device"; + case AddressSpace::threadgroup: + return "threadgroup"; + } + }; + + auto offsetType = + [=](AddressSpace value) -> std::string { + switch (value) { + case AddressSpace::device: + return "uint"; + case AddressSpace::threadgroup: + return "ushort"; + } + }; + + enum class Action { + load, + store, + }; + + struct MemoryAccessDescriptor { + std::optional action; + std::optional addressSpace; + std::optional decodingBF16; + int64_t indentationSpaceCount = 0; + }; + + auto createMemoryAccess = + [=](MemoryAccessDescriptor descriptor) -> std::string { + CCV_NNC_MFA_PRECONDITION(descriptor.action.has_value()); + CCV_NNC_MFA_PRECONDITION(descriptor.addressSpace.has_value()); + CCV_NNC_MFA_PRECONDITION(descriptor.decodingBF16.has_value()); + auto action = descriptor.action.value(); + auto addressSpace = descriptor.addressSpace.value(); + auto decodingBF16 = descriptor.decodingBF16.value(); + std::string indentation(descriptor.indentationSpaceCount, ' '); + + // Determine the arguments. + std::vector arguments; + auto pointerArgument = [=](std::string dataType) { + if (action == Action::load) { + return "const " + keyword(addressSpace) + " " + dataType + " *src"; + } else { + return keyword(addressSpace) + " " + dataType + " *dst"; + } + }; + if (decodingBF16) { + arguments.push_back(pointerArgument("bfloat")); + } else { + arguments.push_back(pointerArgument("U")); + } + arguments.push_back(offsetType(addressSpace) + " elements_per_row"); + arguments.push_back("ushort2 matrix_origin"); + arguments.push_back("bool transpose_matrix = false"); + + // Create the warning comment. + std::string output = ""; + if (decodingBF16) { + output += indentation + "// WARNING: 'T' must be 'float'.\n"; + } else { + output += indentation + "template \n"; + } + + // Create the function signature. 
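    // For orientation, a sketch of the signature this builds for
    // (action = load, addressSpace = device, decodingBF16 = false), assuming
    // the template parameter is the 'U' used in the generated bodies:
    //
    //   template <typename U>
    //   METAL_FUNC void load(const device U *src, uint elements_per_row,
    //                        ushort2 matrix_origin,
    //                        bool transpose_matrix = false) { ... }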
+ output += indentation + "METAL_FUNC void"; + if (action == Action::load) { + output += " load"; + } else { + output += " store"; + } + if (decodingBF16) { + output += "_bfloat"; + } + output += "("; + for (int64_t it = 0; it < arguments.size(); ++it) { + int64_t argumentID = it; + std::string argument = arguments[argumentID]; + + output += argument; + if (argumentID < arguments.size() - 1) { + output += ", "; + } + } + output += ") {\n"; + + auto createAddress = + [=](bool transposed, int64_t offset) -> std::string { + auto lineY = offsetType(addressSpace) + "(matrix_origin.y)"; + auto lineX = "matrix_origin.x + " + std::to_string(offset); + lineX = offsetType(addressSpace) + "(" + lineX + ")"; + + if (transposed) { + return lineX + " * elements_per_row + " + lineY; + } else { + return lineY + " * elements_per_row + " + lineX; + } + }; + + auto createTwoPartAccess = + [=](bool transposed) -> std::vector { + // Generate the addresses. + std::vector lines; + for (int64_t laneID = 0; laneID < 2; ++laneID) { + lines.push_back + (offsetType(addressSpace) + " address" + std::to_string(laneID) + + " = " + createAddress(transposed, laneID)); + } + + if (action == Action::load) { + if (decodingBF16) { + lines.push_back("bfloat memoryForm0 = src[address0]"); + lines.push_back("bfloat memoryForm1 = src[address1]"); + } else { + lines.push_back("U memoryForm0 = src[address0]"); + lines.push_back("U memoryForm1 = src[address1]"); + } + } + + if (action == Action::load) { + if (decodingBF16) { + // Separate the loading logic from the decoding logic for clarity. + lines.push_back + (""); + + // BF16 decoding logic. + lines.push_back + ("bfloat4 registerForm = *(thread bfloat4*)(thread_elements())"); + lines.push_back + ("registerForm[1] = memoryForm0"); + lines.push_back + ("registerForm[3] = memoryForm1"); + lines.push_back + ("((thread bfloat4*)thread_elements())[0] = registerForm"); + } else { + // Perform a type cast natively supported by the hardware. + lines.push_back + ("((thread T*)thread_elements())[0] = T(memoryForm0)"); + lines.push_back + ("((thread T*)thread_elements())[1] = T(memoryForm1)"); + } + } else { + if (decodingBF16) { + // BF16 encoding logic. + lines.push_back + ("bfloat4 registerForm = *(thread bfloat4*)(thread_elements())"); + lines.push_back + ("registerForm[2] = registerForm[1]"); + } else { + // Type casts supported natively by the hardware. + lines.push_back + ("T registerForm0 = ((thread T*)thread_elements())[0]"); + lines.push_back + ("T registerForm1 = ((thread T*)thread_elements())[1]"); + } + } + + if (action == Action::store) { + if (decodingBF16) { + lines.push_back("dst[address0] = registerForm[2]"); + lines.push_back("dst[address1] = registerForm[3]"); + } else { + lines.push_back("dst[address0] = U(registerForm0)"); + lines.push_back("dst[address1] = U(registerForm1)"); + } + } + return lines; + }; + + auto createOnePartAccess = + [=]() -> std::vector { + std::vector lines; + { + auto address = createAddress(false, 0); + lines.push_back("auto combinedAddress = " + address); + } + if (action == Action::load) { + if (decodingBF16) { + lines.push_back + ("bfloat2 memoryForm = *(const " + + keyword(addressSpace) + " packed_bfloat2*)(src + combinedAddress)"); + + // Separate the loading logic from the decoding logic for clarity. + lines.push_back + (""); + + // BF16 decoding logic. 
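        // Background for the generated lines below: a bfloat16 bit pattern is
        // the top 16 bits of the corresponding float32 pattern, so extending
        // BF16 to FP32 amounts to placing each loaded bfloat in the high half
        // of its float lane in the accumulator alias.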
+ lines.push_back + ("bfloat4 registerForm = *(thread bfloat4*)(thread_elements())"); + lines.push_back + ("((thread float*)®isterForm)[1] = *(thread float*)(&memoryForm)"); + lines.push_back + ("((thread bfloat*)®isterForm)[1] = memoryForm[0]"); + lines.push_back + ("((thread bfloat4*)thread_elements())[0] = registerForm"); + } else { + lines.push_back + ("vec memoryForm = *(const " + + keyword(addressSpace) + " vec*)(src + combinedAddress)"); + lines.push_back + ("*(thread_elements()) = vec(memoryForm)"); + } + } else { + if (decodingBF16) { + // BF16 encoding logic. + lines.push_back + ("bfloat4 registerForm = *(thread bfloat4*)(thread_elements())"); + lines.push_back + ("registerForm[2] = registerForm[1]"); + lines.push_back + ("float memoryForm = ((thread float*)®isterForm)[1]"); + lines.push_back + ("*(" + keyword(addressSpace) + " float*)" + + "(dst + combinedAddress) = memoryForm"); + } else { + lines.push_back + ("vec registerForm = *(thread_elements())"); + lines.push_back + ("*(" + keyword(addressSpace) + " vec*)" + + "(dst + combinedAddress) = vec(registerForm)"); + } + } + return lines; + }; + + auto insertBlockContents = + [=](std::vector& body, std::vector block) { + for (std::string line : block) { + // Check whether all characters are whitespace. + bool allCharactersWhitespace = true; + for (int8_t character : line) { + if (isspace(character)) { + + } else { + allCharactersWhitespace = false; + } + } + + // Branch on the result of this check. + if (allCharactersWhitespace) { + body.push_back(" "); + } else { + body.push_back(" " + line + ";"); + } + } + }; + + // Determine the lines of the 'if' block. + std::vector body; + body.push_back("if (transpose_matrix) {"); + insertBlockContents(body, createTwoPartAccess(true)); + + // Determine the lines of the 'else' block. + if (decodingBF16) { + std::vector blockContents; + if (action == Action::load) { + blockContents = createOnePartAccess(); + } else { + blockContents = createTwoPartAccess(false); + } + + body.push_back("} else {"); + insertBlockContents(body, blockContents); + body.push_back("}"); + } else { + body.push_back("} else if (elements_per_row % 2 != 0) {"); + insertBlockContents(body, createTwoPartAccess(false)); + body.push_back("} else {"); + insertBlockContents(body, createOnePartAccess()); + body.push_back("}"); + } + + // Create the function body. + for (std::string line : body) { + output += indentation + " " + line + "\n"; + } + output += indentation + "}\n"; + return output; + }; + + // Add the first section of the shader. + std::string output; + output += R"( +// -*- Metal -*- +//===-- metal_simdgroup_matrix_storage ------------------------------------===// +// Copyright (c) 2024 Philip Turner. See MIT LICENSE +//===----------------------------------------------------------------------===// + +#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE +#define __METAL_SIMDGROUP_MATRIX_STORAGE + +// The layout of threads within a SIMD matrix. +// +// 0 0 1 1 8 8 9 9 +// 2 2 3 3 10 10 11 11 +// 4 4 5 5 12 12 13 13 +// 6 6 7 7 14 14 15 15 +// 16 16 17 17 24 24 25 25 +// 18 18 19 19 26 26 27 27 +// 20 20 21 21 28 28 29 29 +// 22 22 23 23 30 30 31 31 +// +// This is Morton order, a method for coalescing data accesses. It is used +// in a variety of contexts, from ray tracing acceleration structures, to +// nodal-point Laplacians, to sorting large lattices of atoms. 
+// +// Source: https://patents.google.com/patent/US11256518B2 +METAL_FUNC static ushort2 morton_order(ushort thread_index_in_simdgroup) { + ushort lane_id = thread_index_in_simdgroup; + ushort quad_id = lane_id / 4; + + constexpr ushort QUADRANT_SPAN_M = 4; + constexpr ushort THREADS_PER_QUADRANT = 8; + ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M; + ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2); + ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant; + + ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4 + ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2 + ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant; + + return ushort2(N_in_simd, M_in_simd); +} + +#pragma METAL internals : enable +namespace metal +{ + template + struct simdgroup_matrix_storage { + typedef vec storage_type; + + storage_type t; + + METAL_FUNC thread vec* thread_elements() thread { + return reinterpret_cast*>(&t); + } + + METAL_FUNC simdgroup_matrix_storage() thread = default; + + METAL_FUNC simdgroup_matrix_storage(vec thread_elements) thread { + *(this->thread_elements()) = thread_elements; + } + + METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y; + } else { + return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x; + } + } + + METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + return src + matrix_origin.x * elements_per_row + matrix_origin.y; + } else { + return src + matrix_origin.y * elements_per_row + matrix_origin.x; + } + } + +)"; + + MemoryAccessDescriptor desc; + desc.indentationSpaceCount = 4; + + std::vector actions = { Action::load, Action::store }; + std::vector addressSpaces = { + AddressSpace::device, AddressSpace::threadgroup + }; + std::vector decodingBF16s = { false, true }; + for (auto action : actions) { + for (auto addressSpace : addressSpaces) { + for (auto decodingBF16 : decodingBF16s) { + if (!BF16 && decodingBF16) { // Don't need to output BF16 related methods. + continue; + } + desc.action = action; + desc.addressSpace = addressSpace; + + desc.decodingBF16 = decodingBF16; + output += createMemoryAccess(desc); + output += "\n"; + } + } + } + // Add the last section of the header. + output += R"( + template + METAL_FUNC void multiply(simdgroup_matrix_storage a, simdgroup_matrix_storage b, bool accumulate = true) { + if (!accumulate) { + *(thread_elements()) = vec(0); + } + t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage::storage_type()); + } + }; +} // namespace metal +#pragma METAL internals : disable + +#endif // __METAL_SIMDGROUP_MATRIX_STORAGE + +)"; + return output; +} diff --git a/lib/nnc/mfa/v2/GEMMHeaders.hpp b/lib/nnc/mfa/v2/GEMMHeaders.hpp new file mode 100644 index 000000000..6e1a4deda --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMHeaders.hpp @@ -0,0 +1,27 @@ +#ifndef GEMMHeaders_hpp +#define GEMMHeaders_hpp + +#include + +/// Create the source code for the 'metal\_simdgroup\_event' header. +/// +/// I may have found the hardware bug with async copies on M1. If you shoot +/// off an async copy, you need to read from its contents later in the +/// the shader. 
Otherwise, something inside the hardware (like a +/// DispatchSemaphore) will be waiting indefinitely to be notified. The bug +/// is a bit flaky, and only shows up for certain problem configurations. The +/// side effects are catastrophic; the GPU might freeze up until the computer +/// reboots. +/// +/// Workaround: if an async copy from device -> threadgroup is launched, +/// guarantee that both: +/// - The threadgroup will enter another `threadgroup_barrier` before the end of +/// the kernel. +/// - The results of the async copy will be read from. This means at least one +/// thread must dereference a pointer within the region of threadgroup memory. +std::string createMetalSimdgroupEvent(); + +/// Create the source code for the 'metal\_simdgroup\_matrix\_storage' header. +std::string createMetalSimdgroupMatrixStorage(bool BF16); + +#endif /* GEMMHeaders_hpp */ diff --git a/lib/nnc/mfa/v2/GEMMKernel.cpp b/lib/nnc/mfa/v2/GEMMKernel.cpp new file mode 100644 index 000000000..ba96f69fd --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMKernel.cpp @@ -0,0 +1,952 @@ +#include "GEMMKernel.hpp" +#include "GEMMHeaders.hpp" +#include "CodeWriter.hpp" +#include "../ccv_nnc_mfa.hpp" + +#include + +std::string GEMMKernel::memoryName(char operand) const noexcept { + switch (operand) { + case 'A': + return memoryPrecisions.A.name(); + case 'B': + return memoryPrecisions.B.name(); + case 'C': + return memoryPrecisions.C.name(); + case 'S': + return memoryPrecisions.bias.name(); + default: + return ""; + } +} + +std::string GEMMKernel::registerName(char operand) const noexcept { + switch (operand) { + case 'A': + return registerPrecisions.A.name(); + case 'B': + return registerPrecisions.B.name(); + case 'C': + return registerPrecisions.C.name(); + case 'S': + return registerPrecisions.bias.name(); + default: + return ""; + } +} + +unsigned short GEMMKernel::threadgroupMemoryAllocationValue() const noexcept { + unsigned short blockBytesA = blockBytes('A'); + unsigned short blockBytesB = blockBytes('B'); + unsigned short blockBytesC = blockBytes('C'); + return std::max((unsigned short)(blockBytesA + blockBytesB), blockBytesC); +} + +bool GEMMKernel::transposed(char operand) const noexcept { + switch (operand) { + case 'A': + return transposeState[0]; + case 'B': + return transposeState[1]; + case 'C': + return false; + default: + return false; + } +} + +std::string GEMMKernel::leadingDimension(char operand) const noexcept { + return std::string(1, operand) + "_leading_dimension"; +} + +unsigned short GEMMKernel::leadingBlockDimension(char operand) const noexcept { + switch (operand) { + case 'A': + return leadingBlockDimensions[0]; + case 'B': + return leadingBlockDimensions[1]; + case 'C': + return leadingBlockDimensions[2]; + default: + return 0; + } +} + +unsigned short GEMMKernel::trailingBlockDimension(char operand) const noexcept { + auto chooseTrailingBlockDimension = + [=](bool transposeState, unsigned short untransposedRows, unsigned short untransposedColumns) -> unsigned short { + if (transposeState) { + return untransposedColumns; + } else { + return untransposedRows; + } + }; + + switch (operand) { + case 'A': + return chooseTrailingBlockDimension( + transposed('A'), blockDimensions[0], blockDimensions[2]); + case 'B': + return chooseTrailingBlockDimension( + transposed('B'), blockDimensions[2], blockDimensions[1]); + case 'C': + return chooseTrailingBlockDimension( + transposed('C'), blockDimensions[0], blockDimensions[1]); + default: + return 0; + } +} + +unsigned short GEMMKernel::blockBytes(char 
operand) const noexcept {
+  unsigned short output = 1;
+  output *= leadingBlockDimension(operand);
+  output *= trailingBlockDimension(operand);
+
+  GEMMOperandPrecision memoryPrecision;
+  switch (operand) {
+  case 'A':
+    memoryPrecision = memoryPrecisions.A;
+    break;
+  case 'B':
+    memoryPrecision = memoryPrecisions.B;
+    break;
+  case 'C':
+    memoryPrecision = memoryPrecisions.C;
+    break;
+  }
+  output *= memoryPrecision.size();
+  return output;
+}
+
+GEMMKernel::GEMMKernel(GEMMKernelDescriptor descriptor, MTL::Device *const device) {
+  blockDimensions = descriptor.blockDimensions;
+  memoryPrecisions = descriptor.memoryPrecisions;
+  registerPrecisions = descriptor.registerPrecisions;
+  splits = descriptor.splits;
+  transposeState = descriptor.transposeState;
+  preferAsyncLoad = descriptor.preferAsyncLoad;
+  preferAsyncStore = descriptor.preferAsyncStore;
+  useBias = descriptor.useBias;
+  threadgroupSize = 32 * splits[0] * splits[1];
+
+  // Validate the correctness of register precisions.
+  auto checkOperandPair =
+  [=](GEMMOperandPrecision memory, GEMMOperandPrecision register_) -> bool {
+    // Truth table:
+    //
+    // memory | register | valid |
+    // ------ | -------- | ----- |
+    // FP32   | FP32     | yes   |
+    // FP32   | FP16     | no    |
+    // FP32   | BF16     | no    |
+    // FP16   | FP32     | yes   |
+    // FP16   | FP16     | yes   |
+    // FP16   | BF16     | no    |
+    // BF16   | FP32     | yes   |
+    // BF16   | FP16     | no    |
+    // BF16   | BF16     | yes   |
+    //
+    // Optimized form of the logic:
+    //
+    //   If the register precision matches the memory precision,
+    //     return true
+    //   If the register precision equals FP32,
+    //     return true
+    //   Otherwise,
+    //     return false
+    //
+    // The logic statements will change if you introduce custom quantized
+    // formats. The truth table will grow exponentially. You'll need to add
+    // more restrictions on accepted pairs to overcome the combinatorial
+    // explosion.
+    if (register_ == memory) {
+      return true;
+    } else if (register_.value == GEMMOperandPrecision::FP32) {
+      return true;
+    } else {
+      return false;
+    }
+  };
+
+  CCV_NNC_MFA_PRECONDITION
+  (checkOperandPair(memoryPrecisions.A, registerPrecisions.A));
+  CCV_NNC_MFA_PRECONDITION
+  (checkOperandPair(memoryPrecisions.B, registerPrecisions.B));
+  CCV_NNC_MFA_PRECONDITION
+  (checkOperandPair(memoryPrecisions.C, registerPrecisions.C));
+  if (registerPrecisions.C == GEMMOperandPrecision::BF16) {
+    // BF16 has too few mantissa bits to be an accurate accumulator. In
+    // addition, switching from FP32 accumulator to BF16 accumulator slows
+    // down execution speed on both M1/M2 and M3+.
+    CCV_NNC_MFA_PRECONDITION(false);
+  }
+
+  // Declare the size of M and N within a register allocation.
+  registerM = blockDimensions[0] / splits[0];
+  registerN = blockDimensions[1] / splits[1];
+
+  // Retrieve the "padded" block dimensions, otherwise compute analytically
+  // from the true block dimensions.
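  // Worked example (illustrative numbers): with blockDimensions { 48, 48, 32 }
  // and A not transposed, the A block is 48 (M) x 32 (K), so the expected
  // leading block dimension is 32; transposing A makes it 48. A descriptor may
  // specify a larger padded value, but anything smaller trips the
  // precondition below.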
+ auto chooseLeadingBlockDimension = + [=](unsigned short specifiedLeading, bool transposeState, unsigned short untransposedRows, unsigned short untransposedColumns) -> unsigned short { + unsigned short expectedLeading; + if (transposeState) { + expectedLeading = untransposedRows; + } else { + expectedLeading = untransposedColumns; + } + + unsigned short actualLeading; + if (specifiedLeading != 0) { + if (specifiedLeading < expectedLeading) { + CCV_NNC_MFA_PRECONDITION(false && "Leading block dimension was too small."); + } + actualLeading = specifiedLeading; + } else { + actualLeading = expectedLeading; + } + + return actualLeading; + }; + + leadingBlockDimensions[0] = chooseLeadingBlockDimension( + descriptor.leadingBlockDimensions.value_or(simd::ushort3())[0], transposeState[0], + blockDimensions[0], blockDimensions[2]); + leadingBlockDimensions[1] = chooseLeadingBlockDimension( + descriptor.leadingBlockDimensions.value_or(simd::ushort3())[1], transposeState[1], + blockDimensions[2], blockDimensions[1]); + leadingBlockDimensions[2] = chooseLeadingBlockDimension( + descriptor.leadingBlockDimensions.value_or(simd::ushort3())[2], false, + blockDimensions[0], blockDimensions[1]); + + source = createSource(); + + threadgroupMemoryAllocation = threadgroupMemoryAllocationValue(); + + // Compile the shader source. + { + auto string = NS::String::string(source.c_str(), NS::UTF8StringEncoding); + NS::Error* error = nil; + library = NS::TransferPtr(device->newLibrary(string, nil, &error)); + CCV_NNC_MFA_CHECK_ERROR(error); + } +} + +#pragma mark - Source + +std::string GEMMKernel::createSource() const noexcept { + CodeWriter source; + + bool injectBF16Methods = (memoryPrecisions.A == GEMMOperandPrecision::BF16) || (memoryPrecisions.B == GEMMOperandPrecision::BF16) || (memoryPrecisions.C == GEMMOperandPrecision::BF16) || (memoryPrecisions.bias == GEMMOperandPrecision::BF16); + + // Inject the contents of the headers. + source += createMetalSimdgroupEvent() + "\n"; + source += createMetalSimdgroupMatrixStorage(injectBF16Methods) + "\n"; + source += "using namespace metal;\n\n"; + + source.SetValue("TRANSPOSE_STATE_A", std::to_string(bool(transposeState[0]))); + source.SetValue("TRANSPOSE_STATE_B", std::to_string(bool(transposeState[1]))); + source.SetValue("TRANSPOSE_STATE_BIAS", std::to_string(bool(transposeState[2]))); + source.SetValue("BLOCK_DIMENSIONS_M", std::to_string(blockDimensions[0])); + source.SetValue("BLOCK_DIMENSIONS_N", std::to_string(blockDimensions[1])); + source.SetValue("BLOCK_DIMENSIONS_K", std::to_string(blockDimensions[2])); + source.SetValue("REGISTER_M", std::to_string(registerM)); + source.SetValue("REGISTER_N", std::to_string(registerN)); + + source += createConstants(); + + source.SetValue("MEMORY_NAME_A", memoryName('A')); + source.SetValue("MEMORY_NAME_B", memoryName('B')); + source.SetValue("MEMORY_NAME_C", memoryName('C')); + source.SetValue("MEMORY_NAME_BIAS", memoryName('S')); + source.SetValue("REGISTER_NAME_A", registerName('A')); + source.SetValue("REGISTER_NAME_B", registerName('B')); + source.SetValue("REGISTER_NAME_C", registerName('C')); + source.SetValue("REGISTER_NAME_BIAS", registerName('S')); + source.SetValue("SPLITS_N", std::to_string(splits[1])); + + createUtilities(&source); + + source += R"( + +// Metal function arguments. 
+// +// A: the left-hand side matrix +// - dimensions: M x K +// K x M (transposed) +// - memory precision: memA +// - register precision: regA +// +// B: the right-hand side matrix +// - dimensions: K x N +// N x K (transposed) +// - memory precision: memB +// - register precision: regB +// +// C: the output matrix, alternatively the dot product accumulator +// - dimensions: M x N +// - memory precision: memC +// - register precision: regC +// +// threadgroup_block: the chunk of threadgroup memory allocated at runtime +// - ideally 10 KB or less +// - precision: void/8-bit integer to make the pointer arithmetic more legible + +kernel void gemm(device {{MEMORY_NAME_A}} *A [[buffer(0)]], + device {{MEMORY_NAME_B}} *B [[buffer(1)]], + device {{MEMORY_NAME_C}} *C [[buffer(2)]], +)"; + if (useBias) { + source += R"( + device {{MEMORY_NAME_BIAS}} *bias [[buffer(3)]], +)"; + } +source += R"( + threadgroup uchar *threadgroup_block [[threadgroup(0)]], + + uint3 gid [[threadgroup_position_in_grid]], + ushort sidx [[simdgroup_index_in_threadgroup]], + ushort lane_id [[thread_index_in_simdgroup]]) +{ + if (batched) { + A = A + A_batch_stride * gid.z; + B = B + B_batch_stride * gid.z; + C = C + C_batch_stride * gid.z; +)"; + if (useBias) { + source += R"( + bias = bias + bias_batch_stride * gid.z; +)"; + } +source += R"( + } + ushort2 sid(sidx % {{SPLITS_N}}, sidx / {{SPLITS_N}}); + ushort2 morton_offset = morton_order(lane_id); + + // Return early if the SIMD is out of bounds. + // + // There could be some threadgroups where the matrix edge cuts straight + // through the middle of the block. SIMDs on the right or bottom of the + // dividing line must be stopped from causing out-of-bounds accesses. This is + // the reason for the early exit. + uint M_offset = gid.y * M_group; + uint N_offset = gid.x * N_group; + if (M_offset + sid.y * {{REGISTER_M}} >= M || + N_offset + sid.x * {{REGISTER_N}} >= N) { + return; + } + ushort2 offset_in_group(sid.x * {{REGISTER_N}} + morton_offset.x, + sid.y * {{REGISTER_M}} + morton_offset.y); + + // Shift the matrix block within bounds, if possible. + if ((M_shift != 0) && (gid.y * M_group >= M_edge)) { + M_offset -= M_shift; + } + if ((N_shift != 0) && (gid.x * N_group >= N_edge)) { + N_offset -= N_shift; + } + +)"; + + createInitializeC(&source); + + createMultiplyIterations(&source); + + createStoreC(&source); + + source += "}\n\n"; + + return source.ToString(); +} + +std::string GEMMKernel::createConstants() const noexcept { + std::string constants = R"( +// Dimensions of each matrix. +// - Limitations to matrix size: +// - 2^32 in each dimension (M/N/K). +// - Extending to 2^64 may require changing 'uint' to 'ulong'. There is a +// good chance this will significantly degrade performance, and require +// changing the data type of several variables that process addresses. The +// client is responsible for ensuring correctness and performance with +// matrices spanning several billion elements in one direction. +// - The matrix dimensions must be known at compile time, via function +// constants. Dynamic matrix shapes are beyond the scope of this reference +// implementation. Dynamic shapes cause a non-negligible regression to +// shader execution speed. However, they could minimize a compilation +// latency bottleneck in some use cases. +// - Limitations to batch size: +// - Dictated by how the client modifies the code to implement batching. +// - Dynamic batch shapes would likely not harm performance much. 
For example, +// someone could enter an array of pointers/memory offsets to different +// matrices in the batch. Each slice of a 3D thread grid could read a +// different pointer from memory, and use that pointer as the A/B/C matrix. +// Another approach is to restrict the input format, so all matrices are +// stored contiguously in memory. Then, the memory offset could be computed +// analytically from matrix size and the Z dimension in a 3D thread grid. +// +// Another note: +// - The rows of the matrix must be contiguous in memory. Supporting strides +// that differ from the actual matrix dimensions should not be difficult, but +// it is out of scope for this reference kernel. +constant uint M [[function_constant(0)]]; +constant uint N [[function_constant(1)]]; +constant uint K [[function_constant(2)]]; + +// Specify the leading dimensions at PSO creation time. +constant uint A_leading_dimension [[function_constant(5)]]; +constant uint B_leading_dimension [[function_constant(6)]]; +constant uint C_leading_dimension [[function_constant(7)]]; + +// Whether to load the previous value of C, and add it to the accumulator. +constant bool load_previous_C [[function_constant(10)]]; + +// Specify the batch / batch strides at PSO creation time. +constant bool batched [[function_constant(11)]]; + +constant uint A_batch_stride [[function_constant(15)]]; +constant uint B_batch_stride [[function_constant(16)]]; +constant uint C_batch_stride [[function_constant(17)]]; +constant uint bias_batch_stride [[function_constant(18)]]; + +// Whether each matrix is transposed. +constant bool A_trans = {{TRANSPOSE_STATE_A}}; +constant bool B_trans = {{TRANSPOSE_STATE_B}}; +)"; + if (useBias) { + constants += R"( +constant bool bias_trans = {{TRANSPOSE_STATE_BIAS}}; +)"; + } + constants += R"( + +// Define the memory layout of the matrix block. +constant ushort M_group = {{BLOCK_DIMENSIONS_M}}; +constant ushort N_group = {{BLOCK_DIMENSIONS_N}}; +constant ushort K_group = {{BLOCK_DIMENSIONS_K}}; + +// Thresholds that mark the matrix edge. +constant uint M_edge = M - (M % M_group); +constant uint N_edge = N - (N % N_group); + +// Find the number of elements in the final block. If the matrix +// dimensions are perfectly divisibly by block dimensions, we don't want +// this value to be zero. The final block is a full block. +constant ushort M_remainder = (M % {{REGISTER_M}} == 0) + ? {{REGISTER_M}} : M % {{REGISTER_M}}; +constant ushort N_remainder = (N % {{REGISTER_N}} == 0) + ? {{REGISTER_N}} : N % {{REGISTER_N}}; +constant ushort K_remainder = (K % K_group == 0) + ? K_group : K % K_group; +constant ushort K_remainder_padded = (K_remainder + 7) / 8 * 8; + +// Shift the final block, so it doesn't access out-of-bounds memory. +constant ushort M_shift = (M < M_group) ? 0 : {{REGISTER_M}} - M_remainder; +constant ushort N_shift = (N < N_group) ? 0 : {{REGISTER_N}} - N_remainder; + +)"; + return constants; +} + +void GEMMKernel::createUtilities(CodeWriter *const source) const noexcept { + // Add the utility functions. + *source += R"( + +// Indexes into an array of registers. +// +// Calls to this function are expected to be evaluated at compile time. The +// array indices transform into register offsets, which are embedded into the +// assembly code. 
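// Worked example (illustrative): with sram_leading_dim = 32 and
// matrix_origin = (16, 8), the offset below evaluates to
//   (8 / 8) * (32 / 8) + (16 / 8) = 1 * 4 + 2 = 6,
// i.e. the 8x8 tile at row 1, column 2 of a register array four tiles wide.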
+template +METAL_FUNC thread simdgroup_matrix_storage* get_sram( + thread simdgroup_matrix_storage *sram, + ushort sram_leading_dim, + ushort2 matrix_origin +) { + return sram + (matrix_origin.y / 8) * (sram_leading_dim / 8) + (matrix_origin.x / 8); +} +)"; + + std::string createMultiply = R"( + +// One multiply-accumulate loop iteration, or 8 dot products. +METAL_FUNC void multiply_accumulate( +const {{ADDRESS_SPACE}} {{MEMORY_NAME_A}} *A_src, +const {{ADDRESS_SPACE}} {{MEMORY_NAME_B}} *B_src, +thread simdgroup_matrix_storage<{{REGISTER_NAME_A}}> *A_sram, +thread simdgroup_matrix_storage<{{REGISTER_NAME_B}}> *B_sram, +thread simdgroup_matrix_storage<{{REGISTER_NAME_C}}> *C_sram, +ushort k +) { +#pragma clang loop unroll(full) +for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { + ushort2 origin(0, m); + auto A = get_sram(A_sram, 8, origin); + A->{{LOAD_FUNCTION_A}}(A_src, {{LEADING_DIMENSION_A}}, ushort2(k, m), A_trans); +} +#pragma clang loop unroll(full) +for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + ushort2 origin(n, 0); + auto B = get_sram(B_sram, {{REGISTER_N}}, origin); + B->{{LOAD_FUNCTION_B}}(B_src, {{LEADING_DIMENSION_B}}, ushort2(n, k), B_trans); +} +#pragma clang loop unroll(full) +for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + auto A = get_sram(A_sram, 8, ushort2(0, m)); + auto B = get_sram(B_sram, {{REGISTER_N}}, ushort2(n, 0)); + auto C = get_sram(C_sram, {{REGISTER_N}}, ushort2(n, m)); + C->multiply(*A, *B); + } +} +} + +)"; + + // Add the utility functions for the multiply-accumulate inner loop. + if (memoryPrecisions.A == GEMMOperandPrecision::BF16 && registerPrecisions.A == GEMMOperandPrecision::FP32) { + source->SetValue("LOAD_FUNCTION_A", "load_bfloat"); + } else { + source->SetValue("LOAD_FUNCTION_A", "load"); + } + if (memoryPrecisions.B == GEMMOperandPrecision::BF16 && registerPrecisions.B == GEMMOperandPrecision::FP32) { + source->SetValue("LOAD_FUNCTION_B", "load_bfloat"); + } else { + source->SetValue("LOAD_FUNCTION_B", "load"); + } + + source->SetValue("ADDRESS_SPACE", "device"); + source->SetValue("LEADING_DIMENSION_A", leadingDimension('A')); + source->SetValue("LEADING_DIMENSION_B", leadingDimension('B')); + + *source += createMultiply; + + source->SetValue("ADDRESS_SPACE", "threadgroup"); + source->SetValue("LEADING_DIMENSION_A", std::to_string(leadingBlockDimensions[0])); + source->SetValue("LEADING_DIMENSION_B", std::to_string(leadingBlockDimensions[1])); + *source += createMultiply; +} + +#pragma mark - Caching + +void GEMMKernel::createInitializeC(CodeWriter *source) const noexcept { + source->SetValue("REGISTER_M_8_REGISTER_N_8", std::to_string((registerM / 8) * (registerN / 8))); + *source += R"( + + simdgroup_matrix_storage<{{REGISTER_NAME_C}}> C_sram[ + {{REGISTER_M_8_REGISTER_N_8}}]; + + if (load_previous_C) { + )"; + createLoadC(source); + *source += R"( + } else { +)"; + if (useBias) { + if (true) { // TODO: figure why on M3 / M4 this is faster. 
preferAsyncLoad) { + source->SetValue("DIRECT_BIAS_ACCESS_CONDITION", "false"); + } else { + source->SetValue("DIRECT_BIAS_ACCESS_CONDITION", "(M >= M_group) && (N >= N_group)"); + } + if (memoryPrecisions.bias == GEMMOperandPrecision::BF16 && registerPrecisions.bias == GEMMOperandPrecision::FP32) { + source->SetValue("LOAD_FUNCTION_BIAS", "load_bfloat"); + } else { + source->SetValue("LOAD_FUNCTION_BIAS", "load"); + } + std::string declareBiasLocationDevice; + std::string declareBiasLocationThreadgroup; + if (transposeState[2]) { + declareBiasLocationDevice = R"( + uint2 bias_offset(uint(M_offset + offset_in_group.y), 0); + auto bias_src = + simdgroup_matrix_storage<{{MEMORY_NAME_BIAS}}>::apply_offset( + bias, 0, bias_offset); +)"; + declareBiasLocationThreadgroup = R"( + ushort2 bias_block_offset(ushort(offset_in_group.y), 0); + auto bias_src = (threadgroup {{MEMORY_NAME_BIAS}}*)(threadgroup_block); + bias_src = simdgroup_matrix_storage<{{MEMORY_NAME_BIAS}}>::apply_offset( + bias_src, 0, bias_block_offset); +)"; + } else { + declareBiasLocationDevice = R"( + uint2 bias_offset(uint(N_offset + offset_in_group.x), 0); + auto bias_src = + simdgroup_matrix_storage<{{MEMORY_NAME_BIAS}}>::apply_offset( + bias, 0, bias_offset); +)"; + declareBiasLocationThreadgroup = R"( + ushort2 bias_block_offset(ushort(offset_in_group.x), 0); + auto bias_src = (threadgroup {{MEMORY_NAME_BIAS}}*)(threadgroup_block); + bias_src = simdgroup_matrix_storage<{{MEMORY_NAME_BIAS}}>::apply_offset( + bias_src, 0, bias_block_offset); +)"; + } + std::string loadBiasLoop; + if (transposeState[2]) { + loadBiasLoop = R"( + #pragma clang loop unroll(full) + for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { + simdgroup_matrix_storage<{{REGISTER_NAME_BIAS}}> bias; + bias.{{LOAD_FUNCTION_BIAS}}( + bias_src, 0, ushort2(m, 0)); + bias.thread_elements()[0][1] = bias.thread_elements()[0][0]; + + #pragma clang loop unroll(full) + for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + vec<{{REGISTER_NAME_BIAS}}, 2> biasForm = *(bias.thread_elements()); + auto accumulatorForm = vec<{{REGISTER_NAME_C}}, 2>(biasForm); + + ushort2 origin(n, m); + auto C = get_sram(C_sram, {{REGISTER_N}}, origin); + *C = simdgroup_matrix_storage<{{REGISTER_NAME_C}}>(accumulatorForm); + } + } +)"; + } else { + loadBiasLoop = R"( + #pragma clang loop unroll(full) + for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + simdgroup_matrix_storage<{{REGISTER_NAME_BIAS}}> bias; + bias.{{LOAD_FUNCTION_BIAS}}( + bias_src, 0, ushort2(n, 0)); + + #pragma clang loop unroll(full) + for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { + vec<{{REGISTER_NAME_BIAS}}, 2> biasForm = *(bias.thread_elements()); + auto accumulatorForm = vec<{{REGISTER_NAME_C}}, 2>(biasForm); + ushort2 origin(n, m); + auto C = get_sram(C_sram, {{REGISTER_N}}, origin); + *C = simdgroup_matrix_storage<{{REGISTER_NAME_C}}>(accumulatorForm); + } + } +)"; + } + *source += R"( + if ({{DIRECT_BIAS_ACCESS_CONDITION}}) { +)"; + *source += declareBiasLocationDevice; + *source += loadBiasLoop; + *source += R"( + } else { + if (sidx == 0) { + uint2 bias_offset(bias_trans ? M_offset : N_offset, 0); + auto bias_dst = (threadgroup {{MEMORY_NAME_BIAS}}*)(threadgroup_block); + auto bias_src = + simdgroup_matrix_storage<{{MEMORY_NAME_BIAS}}>::apply_offset( + bias, 0, bias_offset); + + ushort bias_tile_dimension = bias_trans + ? min(uint(M_group), M - M_offset) + : min(uint(N_group), N - N_offset); + + // Issue an async copy. 
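      // The bias is a single row: both tiles are bias_tile_dimension x 1, so
      // one simdgroup stages the strip in threadgroup memory and every
      // simdgroup reads it after the barrier below.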
+ simdgroup_event event; + event.async_copy( + bias_dst, 1, ushort2(bias_tile_dimension, 1), + bias_src, 1, ushort2(bias_tile_dimension, 1)); + simdgroup_event::wait(1, &event); + } + threadgroup_barrier(mem_flags::mem_threadgroup); +)"; + *source += declareBiasLocationThreadgroup; + *source += loadBiasLoop; + *source += R"( + // Add a barrier, because you accessed the entries from threadgroup + // memory. + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +)"; + } else { + *source += R"( + #pragma clang loop unroll(full) + for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { + #pragma clang loop unroll(full) + for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + ushort2 origin(n, m); + auto C = get_sram(C_sram, {{REGISTER_N}}, origin); + *C = simdgroup_matrix_storage<{{REGISTER_NAME_C}}>(0); + } + } + } +)"; + } +} + +void GEMMKernel::createLoadC(CodeWriter *source) const noexcept { + if (memoryPrecisions.C == GEMMOperandPrecision::BF16 && registerPrecisions.C == GEMMOperandPrecision::FP32) { + source->SetValue("LOAD_FUNCTION_C", "load_bfloat"); + } else { + source->SetValue("LOAD_FUNCTION_C", "load"); + } + source->SetValue("LEADING_DIMENSION_C", leadingDimension('C')); + source->SetValue("LEADING_BLOCK_DIMENSIONS_C", std::to_string(leadingBlockDimensions[2])); + + if (preferAsyncStore) { + source->SetValue("DIRECT_ACCESS_CONDITION", "false"); + } else { + // In the vanilla GEMM kernel, the extra storing code can be optimized + // away at compile time. The compiler may allocate less registers, and + // occupancy may be greater. + std::string output = "(M >= M_group) && (N >= N_group)"; + + // When accumulate is supported, there are overlapping writes. We must + // sanitize the matrix edge with async copy. The optimization from + // the unified GEMM kernel cannot be applied. + // + // Ideally, a client implementation would add a GEMMKernelDescriptor + // property for whether in-place accumulation was enabled. When false, + // the statements below are not part of the direct-access condition. + // The code for loading C from memory would be elided at + // code-generation time. + // + // MFA has settled on a function constant to toggle accumulation. + output += " && (load_previous_C ? (M_offset == gid.y * M_group) : true)"; + output += " && (load_previous_C ? (N_offset == gid.x * N_group) : true)"; + source->SetValue("DIRECT_ACCESS_CONDITION", output); + } + + *source += R"( + +if ({{DIRECT_ACCESS_CONDITION}}) { + // Fast path for matrices that qualify. + uint2 C_offset(N_offset + offset_in_group.x, + M_offset + offset_in_group.y); + auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset( + C, {{LEADING_DIMENSION_C}}, C_offset); + + // Write the accumulator to device memory. +#pragma clang loop unroll(full) + for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + ushort2 origin(n, m); + auto C = get_sram(C_sram, {{REGISTER_N}}, origin); + C->{{LOAD_FUNCTION_C}}(C_dst, {{LEADING_DIMENSION_C}}, origin); + } + } +} else { + // Slow path for when memory must be handled more carefully. + auto C_block = (threadgroup {{MEMORY_NAME_C}}*)(threadgroup_block); + auto C_block_dst = + simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset( + C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, offset_in_group); + + // Launch the async copy from threadgroup to device memory. 
+ if (sidx == 0) { + uint2 C_offset(N_offset, M_offset); + ushort2 C_tile(min(uint(N_group), N - C_offset.x), + min(uint(M_group), M - C_offset.y)); + auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset( + C, {{LEADING_DIMENSION_C}}, C_offset); + + simdgroup_event event; + event.async_copy( + C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_tile, + C_dst, {{LEADING_DIMENSION_C}}, C_tile); + simdgroup_event::wait(1, &event); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read the accumulator from threadgroup memory. +#pragma clang loop unroll(full) + for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + ushort2 origin(n, m); + auto C = get_sram(C_sram, {{REGISTER_N}}, origin); + C->{{LOAD_FUNCTION_C}}( + C_block_dst, {{LEADING_BLOCK_DIMENSIONS_C}}, origin); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +)"; +} + +void GEMMKernel::createStoreC(CodeWriter *source) const noexcept { + if (memoryPrecisions.C == GEMMOperandPrecision::BF16 && registerPrecisions.C == GEMMOperandPrecision::FP32) { + source->SetValue("STORE_FUNCTION_C", "store_bfloat"); + } else { + source->SetValue("STORE_FUNCTION_C", "store"); + } + + *source += R"( + +if ({{DIRECT_ACCESS_CONDITION}}) { + // Fast path for matrices that qualify. + uint2 C_offset(N_offset + offset_in_group.x, + M_offset + offset_in_group.y); + auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset( + C, {{LEADING_DIMENSION_C}}, C_offset); + + // Write the accumulator to device memory. +#pragma clang loop unroll(full) + for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + ushort2 origin(n, m); + auto C = get_sram(C_sram, {{REGISTER_N}}, origin); + C->{{STORE_FUNCTION_C}}(C_dst, {{LEADING_DIMENSION_C}}, origin); + } + } +} else { + // Slow path for when memory must be handled more carefully. + auto C_block = (threadgroup {{MEMORY_NAME_C}}*)(threadgroup_block); + auto C_block_dst = + simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset( + C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, offset_in_group); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write the accumulator to threadgroup memory. +#pragma clang loop unroll(full) + for (ushort m = 0; m < {{REGISTER_M}}; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < {{REGISTER_N}}; n += 8) { + ushort2 origin(n, m); + auto C = get_sram(C_sram, {{REGISTER_N}}, origin); + C->{{STORE_FUNCTION_C}}( + C_block_dst, {{LEADING_BLOCK_DIMENSIONS_C}}, origin); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Launch the async copy from threadgroup to device memory. + if (sidx == 0) { + uint2 C_offset(gid.x * N_group, gid.y * M_group); + ushort2 C_tile(min(uint(N_group), N - C_offset.x), + min(uint(M_group), M - C_offset.y)); + auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset( + C, {{LEADING_DIMENSION_C}}, C_offset); + + // If we shift successfully, the garbage zone moves from the bottom right + // to the top left. 
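    // Example (illustrative): with REGISTER_M = 16 and M = 100, M_remainder is
    // 4 and M_shift is 12; the edge threadgroups were shifted up by 12 rows in
    // the prologue, and the matching shift is applied to the threadgroup-block
    // read offset here so only valid rows are copied out.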
+ if ((M_shift != 0) || (N_shift != 0)) { + ushort2 C_block_shift(0, 0); + if ((M_shift != 0) && (C_offset.y >= M_edge)) { + C_block_shift.y = M_shift; + } + if ((N_shift != 0) && (C_offset.x >= N_edge)) { + C_block_shift.x = N_shift; + } + C_block = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset( + C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_block_shift); + } + + simdgroup_event event; + event.async_copy( + C_dst, {{LEADING_DIMENSION_C}}, C_tile, + C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_tile); + } +} +)"; +} + +#pragma mark - Multiply + +void GEMMKernel::createMultiplyIterations(CodeWriter *source) const noexcept { + if (preferAsyncLoad) { + source->SetValue("ASYNC_ITERATIONS_START", "0"); + } else { + source->SetValue("ASYNC_ITERATIONS_START", "(K - (K % K_group))"); + } + source->SetValue("PADDED_CEILING_K", "(K + K_remainder_padded - K_remainder)"); + source->SetValue("LEADING_DIMENSION_A", leadingDimension('A')); + source->SetValue("LEADING_DIMENSION_B", leadingDimension('B')); + source->SetValue("LEADING_BLOCK_DIMENSIONS_A", std::to_string(leadingBlockDimensions[0])); + source->SetValue("LEADING_BLOCK_DIMENSIONS_B", std::to_string(leadingBlockDimensions[1])); + source->SetValue("BLOCK_BYTES_A", std::to_string(blockBytes('A'))); + source->SetValue("REGISTER_M_8", std::to_string(registerM / 8)); + source->SetValue("REGISTER_N_8", std::to_string(registerN / 8)); + + *source += R"( + +// Perform the iterations where async copy is avoided. +for (uint k = 0; k < {{ASYNC_ITERATIONS_START}}; k += 8) { + uint2 A_offset(k, M_offset); + uint2 B_offset(N_offset, k); + A_offset += uint2(morton_offset.x, offset_in_group.y); + B_offset += uint2(offset_in_group.x, morton_offset.y); + + auto A_src = simdgroup_matrix_storage<{{MEMORY_NAME_A}}>::apply_offset( + A, {{LEADING_DIMENSION_A}}, A_offset, A_trans); + auto B_src = simdgroup_matrix_storage<{{MEMORY_NAME_B}}>::apply_offset( + B, {{LEADING_DIMENSION_B}}, B_offset, B_trans); + + simdgroup_matrix_storage<{{REGISTER_NAME_A}}> A_sram[ + {{REGISTER_M_8}} * (8 / 8)]; + simdgroup_matrix_storage<{{REGISTER_NAME_B}}> B_sram[ + (8 / 8) * {{REGISTER_N_8}}]; + multiply_accumulate(A_src, B_src, + A_sram, B_sram, C_sram, 0); +} + +// Perform the iterations where async copy is used. +for (uint k = {{ASYNC_ITERATIONS_START}}; k < K; k += K_group) { + auto A_block = (threadgroup {{MEMORY_NAME_A}}*)( + threadgroup_block); + auto B_block = (threadgroup {{MEMORY_NAME_B}}*)( + threadgroup_block + {{BLOCK_BYTES_A}}); + + // Launch an async copy from device to threadgroup memory. 
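  // Note on the tile sizes chosen below: the destination tiles round the K
  // extent up to a multiple of 8 (K_tile_padded), so the unrolled 8-wide
  // multiply_accumulate loop can read whole tiles on the final K block; the
  // async copy's default clamp-to-zero mode is what makes the padded columns
  // safe to read.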
+ if (sidx == 0) { + uint2 A_offset(k, M_offset); + uint2 B_offset(N_offset, k); + auto A_src = simdgroup_matrix_storage<{{MEMORY_NAME_A}}>::apply_offset( + A, {{LEADING_DIMENSION_A}}, A_offset, A_trans); + auto B_src = simdgroup_matrix_storage<{{MEMORY_NAME_B}}>::apply_offset( + B, {{LEADING_DIMENSION_B}}, B_offset, B_trans); + + ushort M_tile_dimension = min(uint(M_group), M - M_offset); + ushort N_tile_dimension = min(uint(N_group), N - N_offset); + ushort K_tile_dimension = min(uint(K_group), K - k); + ushort K_tile_padded = min(uint(K_group), {{PADDED_CEILING_K}} - k); + + ushort2 A_tile_src(K_tile_dimension, M_tile_dimension); + ushort2 B_tile_src(N_tile_dimension, K_tile_dimension); + ushort2 A_tile_dst(K_tile_padded, M_tile_dimension); + ushort2 B_tile_dst(N_tile_dimension, K_tile_padded); + + simdgroup_event events[2]; + events[0].async_copy( + A_block, {{LEADING_BLOCK_DIMENSIONS_A}}, A_tile_dst, + A_src, {{LEADING_DIMENSION_A}}, A_tile_src, A_trans); + events[1].async_copy( + B_block, {{LEADING_BLOCK_DIMENSIONS_B}}, B_tile_dst, + B_src, {{LEADING_DIMENSION_B}}, B_tile_src, B_trans); + simdgroup_event::wait(2, events); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + ushort2 A_block_offset(morton_offset.x, offset_in_group.y); + ushort2 B_block_offset(offset_in_group.x, morton_offset.y); + auto A_block_src = A_block; + auto B_block_src = B_block; + A_block_src = simdgroup_matrix_storage<{{MEMORY_NAME_A}}>::apply_offset( + A_block_src, {{LEADING_BLOCK_DIMENSIONS_A}}, A_block_offset, A_trans); + B_block_src = simdgroup_matrix_storage<{{MEMORY_NAME_B}}>::apply_offset( + B_block_src, {{LEADING_BLOCK_DIMENSIONS_B}}, B_block_offset, B_trans); + + simdgroup_matrix_storage<{{REGISTER_NAME_A}}> A_sram[ + {{REGISTER_M_8}} * (K_group / 8)]; + simdgroup_matrix_storage<{{REGISTER_NAME_B}}> B_sram[ + (K_group / 8) * {{REGISTER_N_8}}]; +#pragma clang loop unroll(full) + for (ushort k = 0; k < K_remainder_padded; k += 8) { + multiply_accumulate(A_block_src, B_block_src, + A_sram, B_sram, C_sram, k); + } + + // Will there be any iterations after this one? + if (k + K_group < K) { + // If so, we haven't reached the edge of either input matrix yet. +#pragma clang loop unroll(full) + for (ushort k = K_remainder_padded; k < K_group; k += 8) { + multiply_accumulate(A_block_src, B_block_src, + A_sram, B_sram, C_sram, k); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } +} + +)"; +} diff --git a/lib/nnc/mfa/v2/GEMMKernel.hpp b/lib/nnc/mfa/v2/GEMMKernel.hpp new file mode 100644 index 000000000..7390ab0f8 --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMKernel.hpp @@ -0,0 +1,73 @@ +#ifndef GEMMKernel_hpp +#define GEMMKernel_hpp + +#include "GEMMKernelDescriptor.hpp" +#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp" +#include + +class CodeWriter; + +struct GEMMKernel { + NS::SharedPtr library; + + std::string source; + + /// A copy of the block dimensions from the descriptor. + /// + /// ## C++ Adaptation + /// + /// Mapping from the Swift implementation: + /// - M -> blockDimensions[0] + /// - N -> blockDimensions[1] + /// - K -> blockDimensions[2] + simd::ushort3 blockDimensions; + + /// These properties are copied from GEMMKernelDescriptor for other helper functions to use. 
+  simd::ushort3 leadingBlockDimensions;
+
+  GEMMOperandPrecisions memoryPrecisions;
+
+  GEMMOperandPrecisions registerPrecisions;
+
+  simd::ushort2 splits;
+
+  simd::uchar3 transposeState;
+
+  bool preferAsyncLoad;
+
+  bool preferAsyncStore;
+
+  bool useBias;
+
+  uint16_t registerM;
+
+  uint16_t registerN;
+
+  unsigned short threadgroupMemoryAllocation;
+
+  /// The number of threads per group.
+  uint16_t threadgroupSize;
+
+  GEMMKernel(GEMMKernelDescriptor descriptor, MTL::Device *const device);
+
+private:
+  std::string memoryName(char operand) const noexcept;
+  std::string registerName(char operand) const noexcept;
+  unsigned short threadgroupMemoryAllocationValue() const noexcept;
+  bool transposed(char operand) const noexcept;
+  std::string leadingDimension(char operand) const noexcept;
+  unsigned short leadingBlockDimension(char operand) const noexcept;
+  unsigned short trailingBlockDimension(char operand) const noexcept;
+  unsigned short blockBytes(char operand) const noexcept;
+
+  std::string createSource() const noexcept;
+  std::string createConstants() const noexcept;
+  void createUtilities(CodeWriter *source) const noexcept;
+  void createInitializeC(CodeWriter *source) const noexcept;
+  void createLoadC(CodeWriter *source) const noexcept;
+  void createMultiplyIterations(CodeWriter *source) const noexcept;
+  void createStoreC(CodeWriter *source) const noexcept;
+};
+
+#endif /* GEMMKernel_hpp */
+
diff --git a/lib/nnc/mfa/v2/GEMMKernelDescriptor.cpp b/lib/nnc/mfa/v2/GEMMKernelDescriptor.cpp new file mode 100644 index 000000000..c9458a8b5 --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMKernelDescriptor.cpp @@ -0,0 +1,107 @@
+#include "GEMMKernelDescriptor.hpp"
+#include "../ccv_nnc_mfa_error.hpp"
+#include "../ccv_nnc_mfa_hash.hpp"
+
+// MARK: - Hash Conformance
+
+bool GEMMKernelDescriptor::operator==(const GEMMKernelDescriptor& rhs) const {
+  return
+  simd_all(blockDimensions == rhs.blockDimensions) &&
+  memoryPrecisions == rhs.memoryPrecisions &&
+  leadingBlockDimensions.has_value() == rhs.leadingBlockDimensions.has_value() &&
+  simd_all(leadingBlockDimensions.value_or(simd::ushort3(UINT16_MAX)) == rhs.leadingBlockDimensions.value_or(simd::ushort3(UINT16_MAX))) &&
+  (preferAsyncLoad == rhs.preferAsyncLoad) &&
+  (preferAsyncStore == rhs.preferAsyncStore) &&
+  registerPrecisions == rhs.registerPrecisions &&
+  simd_all(splits == rhs.splits) &&
+  simd_all(transposeState == rhs.transposeState) &&
+  (useBias == rhs.useBias);
+}
+
+std::size_t std::hash<GEMMKernelDescriptor>::operator()(const GEMMKernelDescriptor& hash) const noexcept {
+  std::size_t seed = 0;
+  using namespace ccv::nnc::mfa::hash;
+  combine_64(seed, pack_64(simd_make_ushort4(hash.blockDimensions, 0)));
+  combine_64(seed, pack_64(simd::ushort4 { hash.memoryPrecisions.A.value, hash.memoryPrecisions.B.value, hash.memoryPrecisions.C.value, hash.memoryPrecisions.bias.value }));
+  if (hash.leadingBlockDimensions.has_value()) {
+    combine_64(seed, pack_64(simd_make_ushort4(hash.leadingBlockDimensions.value())));
+  }
+  combine_32(seed, pack_32(simd::uchar4 { hash.preferAsyncLoad, hash.preferAsyncStore, 0, 0 }));
+  combine_64(seed, pack_64(simd::ushort4 { hash.registerPrecisions.A.value, hash.registerPrecisions.B.value, hash.registerPrecisions.C.value, hash.registerPrecisions.bias.value }));
+  combine_32(seed, pack_32(hash.splits));
+  combine_32(seed, pack_32(simd::uchar4 { hash.transposeState[0], hash.transposeState[1], hash.transposeState[2], hash.useBias }));
+  return seed;
+}
+
+// MARK: - Initializer
+
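The `operator==` overload and the `std::hash<GEMMKernelDescriptor>` specialization above are what allow the descriptor to serve as a key in a key-value cache. A minimal sketch of that usage, assuming only the headers introduced in this diff (the map and the counter are illustrative, not part of the library):

    #include <unordered_map>
    #include "GEMMKernelDescriptor.hpp"

    // Hypothetical bookkeeping keyed by kernel configuration; std::unordered_map
    // picks up the std::hash specialization and operator== defined above.
    std::unordered_map<GEMMKernelDescriptor, int> generationCount;

    void noteGeneration(const GEMMKernelDescriptor &key) {
      generationCount[key] += 1;
    }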
+GEMMKernelDescriptor::GEMMKernelDescriptor(simd::ushort3 blockDimensions, GEMMOperandPrecisions memoryPrecisions, std::optional<simd::ushort3> leadingBlockDimensions, bool preferAsyncLoad, bool preferAsyncStore, GEMMOperandPrecisions registerPrecisions, simd::ushort2 splits, simd::uchar3 transposeState, bool useBias) noexcept {
+  this->blockDimensions = blockDimensions;
+  this->memoryPrecisions = memoryPrecisions;
+  this->leadingBlockDimensions = leadingBlockDimensions;
+  this->preferAsyncLoad = preferAsyncLoad;
+  this->preferAsyncStore = preferAsyncStore;
+  this->registerPrecisions = registerPrecisions;
+  this->splits = splits;
+  this->transposeState = transposeState;
+  this->useBias = useBias;
+}
+
+std::pair<simd::ushort3, std::optional<simd::ushort3>> GEMMKernelDescriptor::getBlockDimensions(MTL::Device* const mtlDevice, const uint32_t coreCount, const simd::uint3 matrixDimensions, const int64_t batchDimension, const GEMMOperandPrecisions memoryPrecisions, const simd::uchar3 transposeState) noexcept {
+  if (mtlDevice->supportsFamily(MTL::GPUFamily(1009))) {
+    return std::make_pair(simd::ushort3 { 32, 32, 8 }, std::nullopt);
+  }
+
+  // Find the actual number of threadgroups, with a large block size.
+  auto ceilDivide =
+  [=](uint32_t target, uint16_t granularity) -> uint32_t {
+    return (target + uint32_t(granularity) - 1) / uint32_t(granularity);
+  };
+  int64_t actualGroups = 1;
+  actualGroups *= ceilDivide(matrixDimensions[0], 48);
+  actualGroups *= ceilDivide(matrixDimensions[1], 48);
+  actualGroups *= batchDimension;
+
+  // Does the kernel use 48x48x24xFP32 (9 KB) or 48x48x32xFP16/BF16 (6 KB)?
+  bool useLargeAllocation = false;
+  if (memoryPrecisions.A == GEMMOperandPrecision::FP32 ||
+      memoryPrecisions.B == GEMMOperandPrecision::FP32 ||
+      memoryPrecisions.C == GEMMOperandPrecision::FP32) {
+    useLargeAllocation = true;
+  }
+
+  // Branch on whether the allocation is large / target occupancy is low.
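+  // For reference, the sizes quoted above: a 48x24 A tile plus a 24x48 B tile
+  // in FP32 occupies 2 * 4608 bytes = 9 KB of threadgroup memory; with 16-bit
+  // operands and K = 32, it is 2 * 3072 bytes = 6 KB. The 9 KB case is what
+  // the heuristic treats as the low-occupancy branch.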
+ if (useLargeAllocation) { + // Remove CoreCount based block size logic, per https://github.com/philipturner/ccv/commit/e8b0682b4344410eb43cdafb9a9c721ba7fdb726 + auto blockDimensions = simd::ushort3 { 48, 48, 24 }; + + // This is verified to be optimal for: + // - (memA, memB, memC) = (FP32, FP32, FP32) + // - (memA, memB, memC) = (FP16, FP16, FP32) + // - (memA, memB, memC) = (FP16, FP32, FP32) + // - (memA, memB, memC) = (FP16, FP32, FP16) + if (!transposeState[0] && !transposeState[1]) { + return std::make_pair(blockDimensions, simd::ushort3 { 24, 48, 48 }); + } else if (!transposeState[0] && transposeState[1]) { + if (memoryPrecisions.B == GEMMOperandPrecision::FP32) { + return std::make_pair(blockDimensions, simd::ushort3 { 24, 28, 48 }); + } else { + return std::make_pair(blockDimensions, simd::ushort3 { 24, 24, 48 }); + } + } else if (transposeState[0] && !transposeState[1]) { + if (memoryPrecisions.A == GEMMOperandPrecision::FP32) { + return std::make_pair(blockDimensions, simd::ushort3 { 52, 48, 48 }); + } else { + return std::make_pair(blockDimensions, simd::ushort3 { 56, 48, 48 }); + } + } else { + if (memoryPrecisions.A == GEMMOperandPrecision::FP32) { + return std::make_pair(blockDimensions, simd::ushort3 { 52, 24, 48 }); + } else { + return std::make_pair(blockDimensions, simd::ushort3 { 56, 24, 48 }); + } + } + } else { + return std::make_pair(simd::ushort3 { 48, 48, 32 }, std::nullopt); + } +} diff --git a/lib/nnc/mfa/v2/GEMMKernelDescriptor.hpp b/lib/nnc/mfa/v2/GEMMKernelDescriptor.hpp new file mode 100644 index 000000000..29154dd9b --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMKernelDescriptor.hpp @@ -0,0 +1,234 @@ +#ifndef GEMMKernelDescriptor_hpp +#define GEMMKernelDescriptor_hpp + +#include "GEMMOperandPrecision.hpp" +#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp" +#include + +struct GEMMDescriptor; + +/// A configuration for a GEMM kernel. +/// +/// The information in this data structure is enough to uniquely identify the +/// kernel. It can be used as a key in a key-value cache. +/// +/// ## Usage +/// +/// The code for generating the GEMM kernel does not include any assumptions +/// about performance. It should only be responsible for correctly generating +/// a shader source, provided a configuration. The user is responsible for +/// choosing that configuration. +struct GEMMKernelDescriptor { + /// Required. The number of matrix elements spanned by each threadgroup. + /// - Parameter M: Number of output columns spanned. + /// - Parameter N: Number of output rows spanned. + /// - Parameter K: Number of loop iterations unrolled. + /// + /// Optimal values: + /// - Apple7 and Apple8: 48x48x24 + /// - Apple9 and later: 32x32x8 + /// + /// To reach optimal performance on Apple7 and Apple8, the recommended default + /// value needs to be modified conditionally. When all three operands have + /// 16-bit memory precisions, change `K` to 32. When the matrix is too small + /// to saturate all of the GPU cores, change all dimensions to 32x32x32. Even + /// smaller blocks can be exploited in low-occupancy cases, but 32x32 and + /// 48x48 are sufficient for general use. + /// + /// For simplicity or an out-of-the-box performance test, one can assume + /// occupancy is always high. But to match the performance of MPS, one must + /// optimize for small problem sizes on large GPUs. 
+ /// + /// ## Choosing Block Size by Precision + /// + /// Legend: + /// - memA: precision for left input matrix, in memory + /// - memB: precision for right input matrix, in memory + /// - memC: precision for output matrix, in memory + /// - regA: precision for left input matrix, in registers + /// - regB: precision for right input matrix, in registers + /// - regC: precision for output matrix, in registers + /// - M1: optimal block size on Apple7 and Apple8 + /// - M3: optimal block size on Apple9 and later + /// + /// memA | memB | memC | regA | regB | regC | M1 | M3 | + /// ---- | ---- | ---- | ---- | ---- | ---- | -------- | ------- | + /// FP16 | FP16 | FP16 | any | any | any | 48x48x32 | 32x32x8 | + /// BF16 | BF16 | BF16 | any | any | any | 48x48x32 | 32x32x8 | + /// FP16 | FP16 | FP32 | any | any | any | 48x48x24 | 32x32x8 | + /// BF16 | BF16 | FP32 | any | any | any | 48x48x24 | 32x32x8 | + /// FP16 | FP32 | FP16 | any | any | any | 48x48x24 | 32x32x8 | + /// BF16 | FP32 | BF16 | any | any | any | 48x48x24 | 32x32x8 | + /// FP32 | FP32 | FP32 | any | any | any | 48x48x24 | 32x32x8 | + /// + /// ## Detecting Low-Occupancy Cases + /// + /// To determine whether the matrix saturates the GPU, divide the output + /// matrix's dimensions by 48x48. Round up to the nearest integer. Then, + /// multiply the number of row blocks by the number of column blocks. The + /// result is the number of threadgroups dispatched. For example, a C matrix + /// with dimensions 768x768 would dispatch 256 threadgroups. If you are + /// batching multiple matrix multiplications into one shader call, multiply + /// the number of threadgroups by the batch count. + /// + /// Next, calculate the target occupancy. Start by finding the GPU core count. + /// This can be accomplished in many ways; there is a heavily tested reference + /// implementation [here](https://github.com/philipturner/applegpuinfo). On + /// macOS, you can query the core count through IORegistry. On iOS, go with a + /// conservative (meaning more likely to overestimate) estimate of 5 cores on + /// A14 - A16, 10 cores on M1 - M2. + /// + /// When one of the operands is 32-bit, the target occupancy is 6 threadgroups + /// per core. When all three operands are 16-bit, the target increases to 9 + /// per core. Multiply the number of cores by the number of threadgroups per + /// core. If the total GPU occupancy is greater than or equal to the number of + /// matrix blocks, use the smaller blocking scheme. + /// + /// For example, the following decision tree would be used on an M1 Max + /// (32 cores). + /// + /// ``` + /// is device Apple9 or later? 
+ /// yes: use block size 32x32x8 + /// no: continue decision tree [selected decision] + /// unsure: use block size 48x48x24-32 + /// + /// compute number of matrix blocks + /// 768x768 / 48x48 = 16.0 x 16.0 + /// round floating point (16.0 x 16.0) + /// to next greatest integer (16 x 16) + /// 16 x 16 x (batch size of 1) = 256 threadgroups + /// + /// compute target occupancies with 48x48 scheme + /// 32 x 6 = 192 [selected when A, B, or C is FP32] + /// 32 x 9 = 288 [selected when every matrix is FP16/BF16] + /// + /// prefer 32x32 when 48x48 has low occupancy + /// if 256 ≤ 192 + /// choose small block size (32x32x32xFP32) + /// else + /// choose large block size (48x48x24xFP32) [selected] + /// if 256 ≤ 288 + /// choose small block size (32x32x32xFP16) [selected] + /// else + /// choose large block size (48x48x32xFP16) + /// ``` + /// + /// ## C++ Adaptation + /// + /// Mapping from the Swift implementation: + /// - M -> blockDimensions[0] + /// - N -> blockDimensions[1] + /// - K -> blockDimensions[2] + simd::ushort3 blockDimensions; + + GEMMOperandPrecisions memoryPrecisions; + + /// Optional. The layout of elements in threadgroup memory. + /// + /// If not specified, the default value matches the actual block dimensions. + /// + /// This property can be used to avoid bank conflicts. For example, of one + /// operand will have 16 FP32 elements per row, there is good chance of + /// increased bank conflicts on M1. One may pad that threadgroup memory + /// allocation to 20 FP32 elements per row. + /// + /// Note that the assignment of M/N/K to row dimensions varies based on which + /// operand is discussed, and what its transpose state is. + /// + /// ## C++ Adaptation + /// + /// Mapping from the Swift implementation: + /// - A.M -> paddedBlockDimensions[0] + /// - A.K -> paddedBlockDimensions[1] + /// - B.K -> paddedBlockDimensions[2] + /// - B.N -> paddedBlockDimensions[3] + /// - C.M -> paddedBlockDimensions[4] + /// - C.N -> paddedBlockDimensions[5] + std::optional leadingBlockDimensions; + + /// Required. Whether async copies will improve performance during the + /// matrix multiplication loop. + /// + /// The default value is `true`. Async copies improve performance on Apple7 + /// and Apple8, but harm performance on Apple9 and later. However, they are + /// essential for correctness when reading from the edges of unaligned + /// matrices. Setting the value to `false` means skipping async copies when + /// doing so will not change the final result. + bool preferAsyncLoad; + + /// Required. Whether async copies will improve performance when storing the + /// accumulator to main memory. + /// + /// There is no default value that will reliably yield consistent performance. + bool preferAsyncStore; + + /// Set the register precision based on the GPU architecture, and your choice + /// for memory precision. The following set of logic statements should provide + /// optimal performance for all permutations of operand precisions. + /// + /// ``` + /// regA is identical to memA + /// regB is identical to memB + /// If memA, memB, and memC are FP16, + /// regC is FP16 + /// else + /// regC is FP32 + /// + /// If earlier than M3 + /// If memA is BF16, + /// regA is FP32 + /// If memB is BF16, + /// regB is FP32 + /// ``` + GEMMOperandPrecisions registerPrecisions; + + /// Required. The array of SIMDs to divide the threadgroup into. 
+ /// + /// Optimal values: + /// - Apple7 and Apple8: 2x2 + /// - Apple9 and later: 1x1 + /// + /// ## C++ Adaptation + /// + /// Mapping from the Swift implementation: + /// - M -> splits[0] + /// - N -> splits[1] + simd::ushort2 splits; + + /// Required. Whether each of the inputs deviates from row-major order. + /// + /// ## C++ Adaptation + /// + /// Mapping from the Swift implementation: + /// - A -> transposeState[0] + /// - B -> transposeState[1] + /// - bias -> transposeState[2] + simd::uchar3 transposeState; + + /// Required. Whether it contains the bias. + bool useBias; + + // MARK: - Functionality from GEMMDescriptor + + GEMMKernelDescriptor() = delete; + + /// Initialize the kernel descriptor. + GEMMKernelDescriptor(simd::ushort3 blockDimensions, GEMMOperandPrecisions memoryPrecisions, std::optional paddedBlockDimensions, bool preferAsyncLoad, bool preferAsyncStore, GEMMOperandPrecisions registerPrecisions, simd::ushort2 splits, simd::uchar3 transposeState, bool useBias) noexcept; + + /// Implementation of the block size selection heuristic. + /// + /// This function initializes the 'blockDimensions' and + /// 'paddedBlockDimensions' properties. + static std::pair> getBlockDimensions(MTL::Device* const mtlDevice, const uint32_t coreCount, const simd::uint3 matrixDimensions, const int64_t batchDimension, const GEMMOperandPrecisions memoryPrecisions, const simd::uchar3 transposeState) noexcept; + + bool operator==(const GEMMKernelDescriptor& rhs) const; +}; + +template<> +struct std::hash +{ + std::size_t operator()(const GEMMKernelDescriptor& hash) const noexcept; +}; + +#endif /* GEMMKernelDescriptor_hpp */ diff --git a/lib/nnc/mfa/v2/GEMMOperandPrecision.hpp b/lib/nnc/mfa/v2/GEMMOperandPrecision.hpp new file mode 100644 index 000000000..e70606679 --- /dev/null +++ b/lib/nnc/mfa/v2/GEMMOperandPrecision.hpp @@ -0,0 +1,89 @@ +#ifndef GEMMOperandPrecision_hpp +#define GEMMOperandPrecision_hpp + +#include +#include + +/// An enumeration of the precisions supported by the kernel. +/// +/// If you wish to support quantized precisions, copy/translate the source code +/// and integrate a modified version into your app. Something similar to a Swift +/// `enum` (e.g. C++ `enum class`) could enumerate the quantization formats +/// used by application code. An exemplary set could be: +/// - FP32 +/// - FP16 +/// - BF16 +/// - signed 8-bit integer +/// - s1ezm7 +/// - FP8 +/// - palletized +/// +/// If you support non-floating-point formats, you have the responsibility of +/// authoring correct and performant GPU code for them. A general rule of thumb, +/// is keep the data compressed in `device` or `threadgroup` memory. Transform +/// into a floating point type while loading into the registers. Keep the +/// accumulator in floating point until the output needs to be written. +/// If the output is quantized, it will be compressed when writing back to +/// `device` memory (or `threadgroup` before the async copy in edge cases). +/// +/// For example, the reference implementation treats BF16 like a quantized +/// integer type on Apple7 and Apple8 GPUs. It is decompressed to FP32 in +/// registers. +class GEMMOperandPrecision { + // Hijack some C++ syntax, making it look like Swift's enumerations with + // member functions. 
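+  // (The pattern wraps a plain enum in a class, so helpers such as name() and
+  // size() below can live on the type itself.)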
+  //
+  // Source: https://stackoverflow.com/a/53284026
+public:
+  enum Value: uint16_t {
+    FP32 = 0,
+    FP16 = 1,
+    BF16 = 2,
+  };
+
+  GEMMOperandPrecision() = default;
+  constexpr GEMMOperandPrecision(Value aPrecision) : value(aPrecision) { }
+
+  // Prevent usage: if(precision)
+  explicit operator bool() const = delete;
+
+  constexpr bool operator==(const GEMMOperandPrecision &rhs) const { return value == rhs.value; }
+  constexpr bool operator!=(const GEMMOperandPrecision &rhs) const { return value != rhs.value; }
+
+  // The MSL keyword corresponding to the precision.
+  std::string name() const noexcept {
+    switch (value) {
+      case FP32:
+        return "float";
+      case FP16:
+        return "half";
+      case BF16:
+        return "bfloat";
+    }
+  }
+
+  // The size of the scalar, in bytes.
+  int64_t size() const noexcept {
+    switch (value) {
+      case FP32:
+        return 4;
+      case FP16:
+        return 2;
+      case BF16:
+        return 2;
+    }
+  }
+
+  Value value;
+};
+
+/// A way to emulate the API of the Swift tuple with labeled members.
+struct GEMMOperandPrecisions {
+  GEMMOperandPrecision A;
+  GEMMOperandPrecision B;
+  GEMMOperandPrecision C;
+  GEMMOperandPrecision bias;
+  constexpr bool operator==(const GEMMOperandPrecisions& rhs) const { return A == rhs.A && B == rhs.B && C == rhs.C && bias == rhs.bias; }
+};
+
+#endif /* GEMMOperandPrecision_hpp */
diff --git a/lib/nnc/mfa/v2/PipelineValue.hpp b/lib/nnc/mfa/v2/PipelineValue.hpp new file mode 100644 index 000000000..f83de6112 --- /dev/null +++ b/lib/nnc/mfa/v2/PipelineValue.hpp @@ -0,0 +1,12 @@
+#ifndef MFA_PIPELINE_VALUE_HPP_
+#define MFA_PIPELINE_VALUE_HPP_
+
+#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
+
+template <typename T>
+struct PipelineValue {
+  T* kernel;
+  NS::SharedPtr<MTL::ComputePipelineState> pipeline;
+};
+
+#endif
diff --git a/lib/nnc/mfa/v2/ShaderCache.hpp b/lib/nnc/mfa/v2/ShaderCache.hpp new file mode 100644 index 000000000..ae312ce04 --- /dev/null +++ b/lib/nnc/mfa/v2/ShaderCache.hpp @@ -0,0 +1,79 @@
+#ifndef MFA_SHADER_CACHE_HPP_
+#define MFA_SHADER_CACHE_HPP_
+
+#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
+#include "PipelineValue.hpp"
+#include "DeviceProperties.hpp"
+
+using TypeInfoRef = std::reference_wrapper<const std::type_info>;
+
+struct Hasher {
+  std::size_t operator()(TypeInfoRef code) const {
+    return code.get().hash_code();
+  }
+};
+
+struct EqualTo {
+  bool operator()(TypeInfoRef lhs, TypeInfoRef rhs) const {
+    return lhs.get() == rhs.get();
+  }
+};
+
+struct TypeErasedUnorderedMap {
+  virtual ~TypeErasedUnorderedMap() = default;
+};
+
+template <typename Key, typename Value>
+struct UnorderedMapWrapper: public TypeErasedUnorderedMap {
+  std::unordered_map<Key, Value> map;
+};
+
+/// A reference implementation of shader caching.
+///
+/// One good design for a shader caching mechanism:
+/// - Two key-value caches.
+/// - The first caches `MTLLibrary` objects.
+///   - Large latency
+///   - Small number of combinatorial possibilities, likely to be shared by
+///     matrices with a different size.
+///   - Don't bother with serializing Metal binary archives to disk. You are
+///     already utilizing the system-wide Metal shader cache.
+/// - The second caches `MTLComputePipelineState` objects.
+///   - Instantiations of the `MTLLibrary` with different function constants.
+///   - Less latency than compiling from source, but still non-negligible. You
+///     can't spawn a new PSO during every call to a matrix multiplication.
+struct ShaderCache {
+private:
+  /// WARNING: Not thread safe. But will the DSL interpreter even use
+  /// multithreading?
+  std::unordered_map<TypeInfoRef, std::unique_ptr<TypeErasedUnorderedMap>, Hasher, EqualTo> libraryCache;
+
+  /// WARNING: Not thread safe.
+  /// But will the DSL interpreter even use multithreading?
+  std::unordered_map<TypeInfoRef, std::unique_ptr<TypeErasedUnorderedMap>, Hasher, EqualTo> pipelineCache;
+public:
+  /// Implementation of the logic for choosing between 'device' and
+  /// 'threadgroup' store.
+  ///
+  /// ## C++ Adaptation
+  ///
+  /// Wrap every call to this function in an autoreleasepool.
+  template <typename T, typename Descriptor, typename KernelDescriptor>
+  PipelineValue<T>* findKernel(Descriptor descriptor, MTL::Device *const device, const DeviceProperties &dprops) noexcept {
+    UnorderedMapWrapper<Descriptor, std::unique_ptr<PipelineValue<T>>> *pipelineCache = static_cast<UnorderedMapWrapper<Descriptor, std::unique_ptr<PipelineValue<T>>> *>(this->pipelineCache.try_emplace(typeid(Descriptor), std::make_unique<UnorderedMapWrapper<Descriptor, std::unique_ptr<PipelineValue<T>>>>()).first->second.get());
+    auto iterator = pipelineCache->map.find(descriptor);
+    if (iterator != pipelineCache->map.end()) {
+      return iterator->second.get();
+    }
+    UnorderedMapWrapper<KernelDescriptor, std::unique_ptr<T>> *libraryCache = static_cast<UnorderedMapWrapper<KernelDescriptor, std::unique_ptr<T>> *>(this->libraryCache.try_emplace(typeid(KernelDescriptor), std::make_unique<UnorderedMapWrapper<KernelDescriptor, std::unique_ptr<T>>>()).first->second.get());
+    auto result = descriptor.findKernel(device, dprops, &libraryCache->map);
+    pipelineCache->map[descriptor] = std::unique_ptr<PipelineValue<T>>(result.second);
+    return result.second;
+  }
+
+  void evict() noexcept {
+    pipelineCache.clear();
+  }
+};
+
+#endif
diff --git a/lib/nnc/mps/ccv_nnc_mps.m b/lib/nnc/mps/ccv_nnc_mps.m index 1f7f2a9c7..de817b65f 100644 --- a/lib/nnc/mps/ccv_nnc_mps.m +++ b/lib/nnc/mps/ccv_nnc_mps.m @@ -427,6 +427,7 @@ static inline void ccv_nnc_mps_graph_key_free(ccv_nnc_mps_graph_key_t key)
 
 void ccv_nnc_mps_clear_graph_executable_cache(void)
 {
+	ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_default_mfa_context());
 	if (!g_graph_executable_cache)
 		return;
 	khiter_t k;