diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m index 4932d027c..8bebd9d48 100644 --- a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m +++ b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m @@ -173,8 +173,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint } ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context(); + const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0)); const int is_mfa_supported = - ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM); + ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM)); size_t a_data_size = 0; if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX) @@ -266,6 +267,91 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint w_data = scratch; w_dataof = a_data_size; } + if (is_mfa_gemv) + { + // This is GEMV, use GEMV kernel. + ccv_nnc_mfa_gemv_params_t params; + if (a_rows == 1 && is_transpose_w) + { + params = (ccv_nnc_mfa_gemv_params_t){ + .data_type = mtl_data_type, + .ncols = w_rows, + .nrows = w_cols, + .fused_bias = bias ? 1 : 0, + }; + } else { + params = (ccv_nnc_mfa_gemv_params_t){ + .data_type = mtl_data_type, + .ncols = a_cols, + .nrows = a_rows, + .fused_bias = bias ? 1 : 0, + }; + } + ccv_nnc_mfa_prepare_gemv(context, params); + + // Creating a new command buffer has a >10 µs penalty CPU-side. Still + // faster the >50 µs penalty for MPSGraph (probably why + // MPSMatrixMultiplication is faster for GEMM). + mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context); + if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX) + { + mtl_buffer_t* tensors[3] = { + mpgetbuffer((ccv_nnc_tensor_t*)a), // A + (mtl_buffer_t*)scratch, // B + NULL, + }; + size_t tensor_offsets[2] = { + a->dataof, // A offset + 0, // B offset + }; + ccv_nnc_mfa_encode_depalettize(context, a_depalettize_params, command_batch, tensors, tensor_offsets); + } + if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX) + { + mtl_buffer_t* tensors[3] = { + mpgetbuffer((ccv_nnc_tensor_t*)w), // A + (mtl_buffer_t*)scratch, // B + NULL, + }; + size_t tensor_offsets[2] = { + w->dataof, // A offset + a_data_size, // B offset + }; + ccv_nnc_mfa_encode_depalettize(context, w_depalettize_params, command_batch, tensors, tensor_offsets); + } + mtl_buffer_t* bias_buffer = NULL; + if (bias) { + bias_buffer = mpgetbuffer((ccv_nnc_tensor_t*)bias); + } + mtl_buffer_t* tensors[5] = { + NULL, + NULL, + mpgetbuffer((ccv_nnc_tensor_t*)b), // C + bias_buffer, // D + NULL, + }; + size_t tensor_offsets[4] = { + 0, + 0, + b->dataof, // C offset + bias ? bias->dataof : 0, // D offset + }; + if (a_rows == 1 && is_transpose_w) + { + tensors[0] = w_data; + tensors[1] = a_data; + tensor_offsets[0] = w_dataof; + tensor_offsets[1] = a_dataof; + } else { + tensors[0] = a_data; + tensors[1] = w_data; + tensor_offsets[0] = a_dataof; + tensor_offsets[1] = w_dataof; + } + ccv_nnc_mfa_encode_gemv(context, params, command_batch, tensors, tensor_offsets); + ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch); + return CCV_NNC_EXEC_SUCCESS; + } // On supported devices, use Metal directly. ccv_nnc_mfa_gemm_params_t params = { .data_type = mtl_data_type, diff --git a/lib/nnc/mfa/ccv_nnc_mfa.cpp b/lib/nnc/mfa/ccv_nnc_mfa.cpp index 4303d533c..08b90113f 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa.cpp @@ -116,6 +116,12 @@ void mfa::cache::prepare(mfa::context* con _mfa_cache_prepare(&map, context, hash); } +template <> +void mfa::cache::prepare(mfa::context* context, mfa::gemv::hash hash) +{ + _mfa_cache_prepare(&map, context, hash); +} + mfa::context::context(MTL::Device* device) { auto* pool = NS::AutoreleasePool::alloc()->init(); diff --git a/lib/nnc/mfa/ccv_nnc_mfa.hpp b/lib/nnc/mfa/ccv_nnc_mfa.hpp index ab528b4c4..461d9401d 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa.hpp +++ b/lib/nnc/mfa/ccv_nnc_mfa.hpp @@ -9,6 +9,7 @@ #include "ccv_nnc_mfa_depalettize.hpp" #include "ccv_nnc_mfa_adam.hpp" #include "ccv_nnc_mfa_cmul.hpp" +#include "ccv_nnc_mfa_gemv.hpp" #ifdef __cplusplus #include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp" @@ -50,6 +51,7 @@ class context { cache depalettize_cache; cache adam_cache; cache cmul_cache; + cache gemv_cache; MTL::Buffer* request_scratch(uint64_t size); }; diff --git a/lib/nnc/mfa/ccv_nnc_mfa_gemv.cpp b/lib/nnc/mfa/ccv_nnc_mfa_gemv.cpp new file mode 100644 index 000000000..f6a6a2985 --- /dev/null +++ b/lib/nnc/mfa/ccv_nnc_mfa_gemv.cpp @@ -0,0 +1,265 @@ +#include "ccv_nnc_mfa.hpp" +#include "ccv_nnc_mfa_hash.hpp" +#include +using namespace ccv::nnc; + +#include + +// MARK: - C + +void ccv_nnc_mfa_prepare_gemv(mfa::context* context, ccv_nnc_mfa_gemv_params_t params) +{ + context->gemv_cache.prepare(context, mfa::gemv::hash(params)); +} + +void ccv_nnc_mfa_encode_gemv(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gemv_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets) +{ + mfa::gemv::hash hash(params); + auto iterator = context->gemv_cache.map.find(hash); + if (iterator == context->gemv_cache.map.end()) { + mfa::precondition_failure("gemv hash not cached.", __LINE__, __FILE__, __FUNCTION__); + } + + auto* pipeline = iterator->second; + auto encoder = command_batch->startCommand(); + + int num_tensors = 0; + while (tensors[num_tensors] != nullptr) { + encoder->setBuffer(tensors[num_tensors], tensor_offsets[num_tensors], NS::UInteger(num_tensors)); + num_tensors += 1; + } + CCV_NNC_MFA_PRECONDITION(num_tensors == 3 || num_tensors == 4); + + encoder->setComputePipelineState(pipeline->gemv_pso.get()); + encoder->useResource(tensors[0], MTL::ResourceUsageRead); + encoder->useResource(tensors[1], MTL::ResourceUsageRead); + encoder->useResource(tensors[2], MTL::ResourceUsageWrite); + if (num_tensors == 4) { + encoder->useResource(tensors[3], MTL::ResourceUsageRead); + } + + auto grid_size = pipeline->grid_size; + CCV_NNC_MFA_PRECONDITION(grid_size.depth > 0); + encoder->dispatchThreadgroups(grid_size, pipeline->group_size); + command_batch->finishCommand(encoder); +} + +// MARK: - C++ + +mfa::gemv::hash::hash(ccv_nnc_mfa_gemv_params_t params) { + data_type = params.data_type; + nrows = params.nrows; + ncols = params.ncols; + fused_bias = params.fused_bias; +} + +bool mfa::gemv::hash::operator==(const mfa::gemv::hash& hash) const { + return + (data_type == hash.data_type) && + (nrows == hash.nrows) && + (ncols == hash.ncols) && + (fused_bias == hash.fused_bias); +} + +std::ostream& operator<<(std::ostream& os, const mfa::gemv::hash& hash) { + os << "mfa::gemv::hash {"; + os << " .data_type = " << hash.data_type << ','; + os << " .nrows = " << hash.nrows << ','; + os << " .ncols = " << hash.ncols << ','; + os << " .fused_bias = " << bool(hash.fused_bias) << " "; + os << "}"; + return os; +} + +std::size_t std::hash::operator()(const mfa::gemv::hash& hash) const noexcept { + std::size_t seed = 0; + using namespace mfa::hash; + combine_64(seed, hash.data_type); + combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.nrows, (unsigned int)hash.ncols })); + combine_32(seed, pack_32(simd::uchar4 { hash.fused_bias, 0, 0, 0 })); + return seed; +} + +mfa::gemv::pipeline::pipeline(mfa::context* context, mfa::gemv::hash hash) { + // FlashNorm not supported for group gemv yet. + CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf)) + + auto* pool = NS::AutoreleasePool::alloc()->init(); + + std::string shader; + if (hash.fused_bias) { + shader = R"( +#include +using namespace metal; + +kernel void gemv( + device const real *src0 [[buffer(0)]], + device const real *src1 [[buffer(1)]], + device real *dst [[buffer(2)]], + device const real *bias [[buffer(3)]], + + uint tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t rb = tgpig * N; + device const real* y = (device const real*)src1; + + if (ncols < 128) { + for (uint row = 0; row < N; ++row) { + uint r1 = rb + row; + if (r1 >= nrows) { + break; + } + device const real* x = (device const real*)src0 + r1 * ncols; + float sumf = 0; + for (uint i = tiisg; i < ncols; i += 32) { + sumf += (real)x[i] * (real)y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[r1] = bias[r1] + all_sum; + } + } + } else { + device const real4* y4 = (device const real4*)y; + for (uint row = 0; row < N; ++row) { + uint r1 = rb + row; + if (r1 >= nrows) { + break; + } + + device const real* x = (device const real*)src0 + r1 * ncols; + device const real4* x4 = (device const real4*)x; + + float sumf = 0; + for (uint i = tiisg; i < ncols / 4; i += 32) { + sumf += (real)x4[i][0] * y4[i][0]; + sumf += (real)x4[i][1] * y4[i][1]; + sumf += (real)x4[i][2] * y4[i][2]; + sumf += (real)x4[i][3] * y4[i][3]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[r1] = bias[r1] + all_sum; + } + } + } +} + )"; + } else { + shader = R"( +#include +using namespace metal; + +kernel void gemv( + device const real *src0 [[buffer(0)]], + device const real *src1 [[buffer(1)]], + device real *dst [[buffer(2)]], + + uint tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t rb = tgpig * N; + device const real* y = (device const real*)src1; + + if (ncols < 128) { + for (uint row = 0; row < N; ++row) { + uint r1 = rb + row; + if (r1 >= nrows) { + break; + } + device const real* x = (device const real*)src0 + r1 * ncols; + float sumf = 0; + for (uint i = tiisg; i < ncols; i += 32) { + sumf += (real)x[i] * (real)y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[r1] = all_sum; + } + } + } else { + device const real4* y4 = (device const real4*)y; + for (uint row = 0; row < N; ++row) { + uint r1 = rb + row; + if (r1 >= nrows) { + break; + } + + device const real* x = (device const real*)src0 + r1 * ncols; + device const real4* x4 = (device const real4*)x; + + float sumf = 0; + for (uint i = tiisg; i < ncols / 4; i += 32) { + sumf += (real)x4[i][0] * y4[i][0]; + sumf += (real)x4[i][1] * y4[i][1]; + sumf += (real)x4[i][2] * y4[i][2]; + sumf += (real)x4[i][3] * y4[i][3]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[r1] = all_sum; + } + } + } +} + )"; + } + + std::string defines = ""; + if (hash.data_type == MTL::DataTypeFloat) { + defines += std::string("typedef float real;"); + defines += "\n"; + defines += std::string("typedef float4 real4;"); + defines += "\n"; + } else { + defines += std::string("typedef half real;"); + defines += "\n"; + defines += std::string("typedef half4 real4;"); + defines += "\n"; + } + + defines += "constant uint N = 8;\n"; + defines += "constant uint ncols = "; + defines += std::to_string(hash.ncols) + ";"; + defines += "\n"; + defines += "constant uint nrows = "; + defines += std::to_string(hash.nrows) + ";"; + defines += "\n"; + this->group_size = MTL::Size(32, 1, 1); + this->grid_size = MTL::Size((hash.nrows + 8 - 1) / 8, 1, 1); + + auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init()); + NS::SharedPtr* pso = &gemv_pso; + + std::string source = defines; + if (METAL_LOG_LEVEL(context) >= 4) { + std::cerr << source << std::endl; + } + source += shader; + + NS::Error *error = nullptr; + auto swift_source = NS::String::string(source.c_str(), + NS::UTF8StringEncoding); + auto library = NS::TransferPtr(context->device->newLibrary(swift_source, nullptr, &error)); + if (!library) { + CCV_NNC_MFA_CHECK_ERROR(error) + } + + auto swift_name = NS::String::string("gemv", NS::UTF8StringEncoding); + auto function = NS::TransferPtr(library->newFunction(swift_name, constants.get(), &error)); + if (!function) { + CCV_NNC_MFA_CHECK_ERROR(error) + } + + *pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error)); + if (!*pso) { + CCV_NNC_MFA_CHECK_ERROR(error) + } + + pool->drain(); +} diff --git a/lib/nnc/mfa/ccv_nnc_mfa_gemv.hpp b/lib/nnc/mfa/ccv_nnc_mfa_gemv.hpp new file mode 100644 index 000000000..12097da32 --- /dev/null +++ b/lib/nnc/mfa/ccv_nnc_mfa_gemv.hpp @@ -0,0 +1,66 @@ +#ifndef GUARD_ccv_nnc_mfa_gemv_hpp +#define GUARD_ccv_nnc_mfa_gemv_hpp + +typedef struct { + uint64_t data_type; + uint32_t nrows; + uint32_t ncols; + uint8_t fused_bias; +} ccv_nnc_mfa_gemv_params_t; + +#ifdef __cplusplus +#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp" +#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp" +#include + +namespace ccv { +namespace nnc { +namespace mfa { +namespace gemv { + +class hash { +public: + uint64_t data_type; + uint32_t nrows; + uint32_t ncols; + uint8_t fused_bias; + + hash(ccv_nnc_mfa_gemv_params_t); + + bool operator==(const hash& rhs) const; +}; + +class pipeline { +public: + NS::SharedPtr gemv_pso; + + MTL::Size grid_size; + MTL::Size group_size; + + pipeline(context* context, hash hash); +}; + +} // namespace gemv +} // namespace mfa +} // namespace nnc +} // namespace ccv + +std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::gemv::hash& hash); + +template<> +struct std::hash +{ + std::size_t operator()(const ccv::nnc::mfa::gemv::hash& hash) const noexcept; +}; + +extern "C" { +#endif // __cplusplus + +void ccv_nnc_mfa_prepare_gemv(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gemv_params_t params); +void ccv_nnc_mfa_encode_gemv(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gemv_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif diff --git a/lib/nnc/mfa/makefile b/lib/nnc/mfa/makefile index 57a116481..860b16834 100644 --- a/lib/nnc/mfa/makefile +++ b/lib/nnc/mfa/makefile @@ -2,7 +2,7 @@ include ../../config.mk CFLAGS := -std=c++17 -O3 -Wall -I"../../" $(CFLAGS) -SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp 3rdparty/metal-cpp/Dispatch.cpp +SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gemv.cpp 3rdparty/metal-cpp/Dispatch.cpp SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(SRCS))) diff --git a/test/int/nnc/mpsblas.tests.c b/test/int/nnc/mpsblas.tests.c index 7be3c03a8..cd92cfe42 100644 --- a/test/int/nnc/mpsblas.tests.c +++ b/test/int/nnc/mpsblas.tests.c @@ -602,6 +602,58 @@ TEST_CASE("mps forward gemm in half precision") ccv_nnc_tensor_free(hbias2); } +TEST_CASE("mps forward gemv in half precision, variant 1") +{ + GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_FORWARD, CCV_NNC_BACKEND_MPS)); + dsfmt_t dsfmt; + dsfmt_init_gen_rand(&dsfmt, 0); + ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 1, 128), 0); + ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 64, 128), 0); + ccv_nnc_tensor_t* bias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 64), 0); + ccv_nnc_tensor_t* b = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 1, 64), 0); + + ccv_nnc_tensor_t* ha = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 128), 0); + ccv_nnc_tensor_t* hw = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 64, 128), 0); + ccv_nnc_tensor_t* hbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 64), 0); + ccv_nnc_tensor_t* hb = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 64), 0); + int i; + for (i = 0; i < 64 * 128; i++) + hw->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (64 * 128); + for (i = 0; i < 64; i++) + hbias->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + ccv_nnc_tensor_t* ha1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 128), 0); + for (i = 0; i < 128; i++) + ha1->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < 128; i++) + ha->data.f32[i] = ha1->data.f32[i]; + ccv_nnc_tensor_t* ha2 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 1, 128), 0); + ccv_nnc_tensor_t* hw2 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 64, 128), 0); + ccv_nnc_tensor_t* hbias2 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 64), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ha1, hw, hbias), TENSOR_LIST(ha2, hw2, hbias2), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ha2, hw2, hbias2), TENSOR_LIST(a, w, bias), 0); + ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(ha, hw, hbias), TENSOR_LIST(hb), 0); + ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(b), 0); + ccv_nnc_tensor_t* tb = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 1, 64), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(b), TENSOR_LIST(tb), 0); + ccv_nnc_tensor_t* tb1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 64), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(tb), TENSOR_LIST(tb1), 0); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, tb1->data.f32, hb->data.f32, 64, 1e-3, "GPU computed output should be the same as CPU computed ones"); + ccv_nnc_tensor_free(a); + ccv_nnc_tensor_free(w); + ccv_nnc_tensor_free(bias); + ccv_nnc_tensor_free(b); + ccv_nnc_tensor_free(tb); + ccv_nnc_tensor_free(ha); + ccv_nnc_tensor_free(ha1); + ccv_nnc_tensor_free(tb1); + ccv_nnc_tensor_free(hw); + ccv_nnc_tensor_free(hbias); + ccv_nnc_tensor_free(hb); + ccv_nnc_tensor_free(ha2); + ccv_nnc_tensor_free(hw2); + ccv_nnc_tensor_free(hbias2); +} + TEST_CASE("mps forward gemm no bias") { GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_FORWARD, CCV_NNC_BACKEND_MPS)); @@ -686,6 +738,94 @@ TEST_CASE("mps forward gemm no bias in half precision") ccv_nnc_tensor_free(hw2); } +TEST_CASE("mps forward gemv in half precision no bias, variant 1") +{ + GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_FORWARD, CCV_NNC_BACKEND_MPS)); + dsfmt_t dsfmt; + dsfmt_init_gen_rand(&dsfmt, 0); + ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 1, 128), 0); + ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 64, 128), 0); + ccv_nnc_tensor_t* b = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 1, 64), 0); + + ccv_nnc_tensor_t* ha = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 128), 0); + ccv_nnc_tensor_t* hw = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 64, 128), 0); + ccv_nnc_tensor_t* hb = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 64), 0); + int i; + for (i = 0; i < 64 * 128; i++) + hw->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (64 * 128); + ccv_nnc_tensor_t* ha1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 128), 0); + for (i = 0; i < 128; i++) + ha1->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < 128; i++) + ha->data.f32[i] = ha1->data.f32[i]; + ccv_nnc_tensor_t* ha2 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 1, 128), 0); + ccv_nnc_tensor_t* hw2 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 64, 128), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ha1, hw), TENSOR_LIST(ha2, hw2), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ha2, hw2), TENSOR_LIST(a, w), 0); + ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(ha, hw), TENSOR_LIST(hb), 0); + ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, w), TENSOR_LIST(b), 0); + ccv_nnc_tensor_t* tb = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 1, 64), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(b), TENSOR_LIST(tb), 0); + ccv_nnc_tensor_t* tb1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 64), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(tb), TENSOR_LIST(tb1), 0); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, tb1->data.f32, hb->data.f32, 64, 1e-3, "GPU computed output should be the same as CPU computed ones"); + ccv_nnc_tensor_free(a); + ccv_nnc_tensor_free(w); + ccv_nnc_tensor_free(b); + ccv_nnc_tensor_free(tb); + ccv_nnc_tensor_free(ha); + ccv_nnc_tensor_free(ha1); + ccv_nnc_tensor_free(tb1); + ccv_nnc_tensor_free(hw); + ccv_nnc_tensor_free(hb); + ccv_nnc_tensor_free(ha2); + ccv_nnc_tensor_free(hw2); +} + +TEST_CASE("mps forward gemv in half precision no bias, variant 2") +{ + GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_FORWARD, CCV_NNC_BACKEND_MPS)); + dsfmt_t dsfmt; + dsfmt_init_gen_rand(&dsfmt, 0); + ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 64, 128), 0); + ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 128, 1), 0); + ccv_nnc_tensor_t* b = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, 64, 1), 0); + + ccv_nnc_tensor_t* hw = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 64, 128), 0); + ccv_nnc_tensor_t* ha = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 128, 1), 0); + ccv_nnc_tensor_t* hb = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 64, 1), 0); + int i; + for (i = 0; i < 64 * 128; i++) + hw->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (64 * 128); + ccv_nnc_tensor_t* ha1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 128, 1), 0); + for (i = 0; i < 128; i++) + ha1->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < 128; i++) + ha->data.f32[i] = ha1->data.f32[i]; + ccv_nnc_tensor_t* hw2 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 64, 128), 0); + ccv_nnc_tensor_t* ha2 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 128, 1), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ha1, hw), TENSOR_LIST(ha2, hw2), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ha2, hw2), TENSOR_LIST(a, w), 0); + ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, NO_TRANSPOSE), ccv_nnc_no_hint, 0, TENSOR_LIST(hw, ha), TENSOR_LIST(hb), 0); + ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, NO_TRANSPOSE), ccv_nnc_no_hint, 0, TENSOR_LIST(w, a), TENSOR_LIST(b), 0); + ccv_nnc_tensor_t* tb = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 64, 1), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(b), TENSOR_LIST(tb), 0); + ccv_nnc_tensor_t* tb1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 64, 1), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(tb), TENSOR_LIST(tb1), 0); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, tb1->data.f32, hb->data.f32, 64, 1e-3, "GPU computed output should be the same as CPU computed ones"); + ccv_nnc_tensor_free(a); + ccv_nnc_tensor_free(w); + ccv_nnc_tensor_free(b); + ccv_nnc_tensor_free(tb); + ccv_nnc_tensor_free(ha); + ccv_nnc_tensor_free(ha1); + ccv_nnc_tensor_free(tb1); + ccv_nnc_tensor_free(hw); + ccv_nnc_tensor_free(hb); + ccv_nnc_tensor_free(ha2); + ccv_nnc_tensor_free(hw2); +} + TEST_CASE("mps handle permute") { GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_FORWARD, CCV_NNC_BACKEND_MPS)); @@ -904,7 +1044,7 @@ TEST_CASE("generalized batched gemm with batch (2, 4) with bias and broadcast co ccv_nnc_tensor_free(bt); } -TEST_CASE("generalized batched backward gemm with batch (2, 4) compare cublas") +TEST_CASE("generalized batched backward gemm with batch (2, 4) compare mps") { GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_BACKWARD, CCV_NNC_BACKEND_MPS)); // This is a particular batched gemm which treat every dimensions other than the last two as batching. @@ -970,7 +1110,7 @@ TEST_CASE("generalized batched backward gemm with batch (2, 4) compare cublas") ccv_nnc_tensor_free(tdw); } -TEST_CASE("generalized batched backward gemm with batch (2, 4) and broadcast compare cublas") +TEST_CASE("generalized batched backward gemm with batch (2, 4) and broadcast compare mps") { GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_BACKWARD, CCV_NNC_BACKEND_MPS)); // This is a particular batched gemm which treat every dimensions other than the last two as batching. @@ -1026,7 +1166,7 @@ TEST_CASE("generalized batched backward gemm with batch (2, 4) and broadcast com ccv_nnc_tensor_free(tdw); } -TEST_CASE("generalized batched backward gemm with batch (2, 4) with bias compare cublas") +TEST_CASE("generalized batched backward gemm with batch (2, 4) with bias compare mps") { GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_BACKWARD, CCV_NNC_BACKEND_MPS)); // This is a particular batched gemm which treat every dimensions other than the last two as batching. @@ -1099,7 +1239,7 @@ TEST_CASE("generalized batched backward gemm with batch (2, 4) with bias compare ccv_nnc_tensor_free(tdbias); } -TEST_CASE("generalized batched backward gemm with batch (2, 4) with bias and broadcast compare cublas") +TEST_CASE("generalized batched backward gemm with batch (2, 4) with bias and broadcast compare mps") { GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_BACKWARD, CCV_NNC_BACKEND_MPS)); // This is a particular batched gemm which treat every dimensions other than the last two as batching.