From 6c6ade0ed74fbed021f9bc9f991a6e2d069f2d19 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Sun, 4 Aug 2024 19:20:38 -0400 Subject: [PATCH] Cover all shapes for add / cast. --- lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m | 4 -- lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m | 2 +- lib/nnc/mfa/ccv_nnc_mfa_add.cpp | 61 +++++++++++++++++++++--- lib/nnc/mfa/ccv_nnc_mfa_cast.cpp | 62 +++++++++++++++++++++++-- 4 files changed, 113 insertions(+), 16 deletions(-) diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m index b5c6be971..4ae5258f8 100644 --- a/lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m +++ b/lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m @@ -122,10 +122,6 @@ static int _ccv_nnc_add_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, use_mfa = false; fallback_reason = "Broadcast semantics unsupported."; } - if (length % 4 != 0) { - use_mfa = false; - fallback_reason = "Length cannot divide by 4."; - } } if (use_mfa) { if (!CCV_IS_TENSOR_CONTIGUOUS(a) || !CCV_IS_TENSOR_CONTIGUOUS(b) || !CCV_IS_TENSOR_CONTIGUOUS(c)) { diff --git a/lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m b/lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m index 005c80e88..24b71d85a 100644 --- a/lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m +++ b/lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m @@ -362,7 +362,7 @@ static int _ccv_nnc_datatype_conversion(const ccv_nnc_cmd_t cmd, const ccv_nnc_h assert(output_size <= input_size); int i; @autoreleasepool { - bool use_mfa = false; + bool use_mfa = true; const char *fallback_reason = NULL; ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context(); diff --git a/lib/nnc/mfa/ccv_nnc_mfa_add.cpp b/lib/nnc/mfa/ccv_nnc_mfa_add.cpp index f27858e67..5d0714044 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_add.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_add.cpp @@ -73,7 +73,26 @@ mfa::add::pipeline::pipeline(mfa::context* context, mfa::add::hash hash) { auto* pool = NS::AutoreleasePool::alloc()->init(); - std::string shader = R"( + std::string shader; + // In this case, we can ignore 
the boundary check. + if (hash.length % (4 * 256) == 0) { + shader = R"( +#include +using namespace metal; + +kernel void add( + device const real4 *src0 [[buffer(0)]], + device const real4 *src1 [[buffer(1)]], + device real4 *dst [[buffer(2)]], + + uint3 tpig [[thread_position_in_grid]] +) { + const uint idx = tpig.x; + dst[idx] = src0[idx] + src1[idx]; +} + )"; + } else if (hash.length % 4 == 0) { + shader = R"( #include using namespace metal; @@ -90,21 +109,51 @@ kernel void add( dst[idx] = src0[idx] + src1[idx]; } )"; + } else { + shader = R"( +#include +using namespace metal; + +kernel void add( + device const real *src0 [[buffer(0)]], + device const real *src1 [[buffer(1)]], + device real *dst [[buffer(2)]], + + uint3 tpig [[thread_position_in_grid]] +) { + const uint idx = tpig.x; + if (idx >= count) + return; + dst[idx] = src0[idx] + src1[idx]; +} + )"; + } std::string defines = ""; if (hash.data_type == MTL::DataTypeFloat) { defines += std::string("typedef float4 real4;"); defines += "\n"; + defines += std::string("typedef float real;"); + defines += "\n"; } else { defines += std::string("typedef half4 real4;"); defines += "\n"; + defines += std::string("typedef half real;"); + defines += "\n"; } - defines += "constant uint count = "; - CCV_NNC_MFA_PRECONDITION(hash.length % 4 == 0) - const unsigned int count = hash.length / 4; - defines += std::to_string(count) + ";"; - defines += "\n"; + unsigned int count; + if (hash.length % 4 == 0) { + count = hash.length / 4; + } else { + count = hash.length; + } + // Only boundary check needs this const in the shader. 
+ if (hash.length % (4 * 256) != 0) { + defines += "constant uint count = "; + defines += std::to_string(count) + ";"; + defines += "\n"; + } this->group_size = MTL::Size(256, 1, 1); const int num_blocks = (count + 255) / 256; this->grid_size = MTL::Size(num_blocks, 1, 1); diff --git a/lib/nnc/mfa/ccv_nnc_mfa_cast.cpp b/lib/nnc/mfa/ccv_nnc_mfa_cast.cpp index cf56b769d..c583ee46e 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_cast.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_cast.cpp @@ -79,7 +79,42 @@ mfa::cast::pipeline::pipeline(mfa::context* context, mfa::cast::hash hash) { auto* pool = NS::AutoreleasePool::alloc()->init(); - std::string shader = R"( + std::string shader; + // In this case, we can igore the boundary check. + if (hash.length % (4 * 256) == 0) { + shader = R"( +#include +using namespace metal; + +kernel void cast( + device original_real4 *src [[buffer(0)]], + device real4 *destination [[buffer(1)]], + + uint3 tpig [[thread_position_in_grid]] +) { + const uint idx = tpig.x; + destination[idx] = (real4)(src[idx]); +} + )"; + } else if (hash.length % 4 == 0) { + shader = R"( +#include +using namespace metal; + +kernel void cast( + device original_real4 *src [[buffer(0)]], + device real4 *destination [[buffer(1)]], + + uint3 tpig [[thread_position_in_grid]] +) { + const uint idx = tpig.x; + if (idx >= count) + return; + destination[idx] = (real4)(src[idx]); +} + )"; + } else { + shader = R"( #include using namespace metal; @@ -95,29 +130,46 @@ kernel void cast( destination[idx] = (real)(src[idx]); } )"; + } std::string defines = ""; if (hash.data_type == MTL::DataTypeFloat) { defines += std::string("typedef float real;"); defines += "\n"; + defines += std::string("typedef float4 real4;"); + defines += "\n"; } else { defines += std::string("typedef half real;"); defines += "\n"; + defines += std::string("typedef half4 real4;"); + defines += "\n"; } if (hash.original_data_type == MTL::DataTypeFloat) { defines += std::string("typedef float original_real;"); defines += "\n"; 
+ defines += std::string("typedef float4 original_real4;"); + defines += "\n"; } else { defines += std::string("typedef half original_real;"); defines += "\n"; + defines += std::string("typedef half4 original_real4;"); + defines += "\n"; } - defines += "constant uint count = "; - defines += std::to_string(hash.length) + ";"; - defines += "\n"; + unsigned int count; + if (hash.length % 4 == 0) { + count = hash.length / 4; + } else { + count = hash.length; + } + if (hash.length % (4 * 256) != 0) { + defines += "constant uint count = "; + defines += std::to_string(count) + ";"; + defines += "\n"; + } this->group_size = MTL::Size(256, 1, 1); - const int num_blocks = (hash.length + 255) / 256; + const int num_blocks = (count + 255) / 256; this->grid_size = MTL::Size(num_blocks, 1, 1); auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());