Skip to content

Commit

Permalink
Add custom gemv kernel that performs slightly better than MPS.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Feb 26, 2024
1 parent 24e67e0 commit ca881af
Show file tree
Hide file tree
Showing 7 changed files with 571 additions and 6 deletions.
88 changes: 87 additions & 1 deletion lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
}

ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0));
const int is_mfa_supported =
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM);
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM));

size_t a_data_size = 0;
if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
Expand Down Expand Up @@ -266,6 +267,91 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
w_data = scratch;
w_dataof = a_data_size;
}
if (is_mfa_gemv)
{
// This is GEMV, use GEMV kernel.
ccv_nnc_mfa_gemv_params_t params;
if (a_rows == 1 && is_transpose_w)
{
params = (ccv_nnc_mfa_gemv_params_t){
.data_type = mtl_data_type,
.ncols = w_rows,
.nrows = w_cols,
.fused_bias = bias ? 1 : 0,
};
} else {
params = (ccv_nnc_mfa_gemv_params_t){
.data_type = mtl_data_type,
.ncols = a_cols,
.nrows = a_rows,
.fused_bias = bias ? 1 : 0,
};
}
ccv_nnc_mfa_prepare_gemv(context, params);

// Creating a new command buffer has a >10 µs penalty CPU-side. Still
// faster than the >50 µs penalty for MPSGraph (probably why
// MPSMatrixMultiplication is faster for GEMM).
mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
{
mtl_buffer_t* tensors[3] = {
mpgetbuffer((ccv_nnc_tensor_t*)a), // A
(mtl_buffer_t*)scratch, // B
NULL,
};
size_t tensor_offsets[2] = {
a->dataof, // A offset
0, // B offset
};
ccv_nnc_mfa_encode_depalettize(context, a_depalettize_params, command_batch, tensors, tensor_offsets);
}
if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
{
mtl_buffer_t* tensors[3] = {
mpgetbuffer((ccv_nnc_tensor_t*)w), // A
(mtl_buffer_t*)scratch, // B
NULL,
};
size_t tensor_offsets[2] = {
w->dataof, // A offset
a_data_size, // B offset
};
ccv_nnc_mfa_encode_depalettize(context, w_depalettize_params, command_batch, tensors, tensor_offsets);
}
mtl_buffer_t* bias_buffer = NULL;
if (bias) {
bias_buffer = mpgetbuffer((ccv_nnc_tensor_t*)bias);
}
mtl_buffer_t* tensors[5] = {
NULL,
NULL,
mpgetbuffer((ccv_nnc_tensor_t*)b), // C
bias_buffer, // D
NULL,
};
size_t tensor_offsets[4] = {
0,
0,
b->dataof, // C offset
bias ? bias->dataof : 0, // D offset
};
if (a_rows == 1 && is_transpose_w)
{
tensors[0] = w_data;
tensors[1] = a_data;
tensor_offsets[0] = w_dataof;
tensor_offsets[1] = a_dataof;
} else {
tensors[0] = a_data;
tensors[1] = w_data;
tensor_offsets[0] = a_dataof;
tensor_offsets[1] = w_dataof;
}
ccv_nnc_mfa_encode_gemv(context, params, command_batch, tensors, tensor_offsets);
ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch);
return CCV_NNC_EXEC_SUCCESS;
}
// On supported devices, use Metal directly.
ccv_nnc_mfa_gemm_params_t params = {
.data_type = mtl_data_type,
Expand Down
6 changes: 6 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ void mfa::cache<mfa::cmul::hash, mfa::cmul::pipeline>::prepare(mfa::context* con
_mfa_cache_prepare(&map, context, hash);
}

// Explicit specialization so the generic cache can compile/retain GEMV
// pipelines; delegates to the shared cache-fill helper (no-op if the hash
// is already present in `map`).
template <>
void mfa::cache<mfa::gemv::hash, mfa::gemv::pipeline>::prepare(mfa::context* context, mfa::gemv::hash hash)
{
_mfa_cache_prepare(&map, context, hash);
}

mfa::context::context(MTL::Device* device)
{
auto* pool = NS::AutoreleasePool::alloc()->init();
Expand Down
2 changes: 2 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ccv_nnc_mfa_depalettize.hpp"
#include "ccv_nnc_mfa_adam.hpp"
#include "ccv_nnc_mfa_cmul.hpp"
#include "ccv_nnc_mfa_gemv.hpp"

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
Expand Down Expand Up @@ -50,6 +51,7 @@ class context {
cache<depalettize::hash, depalettize::pipeline> depalettize_cache;
cache<adam::hash, adam::pipeline> adam_cache;
cache<cmul::hash, cmul::pipeline> cmul_cache;
cache<gemv::hash, gemv::pipeline> gemv_cache;

MTL::Buffer* request_scratch(uint64_t size);
};
Expand Down
265 changes: 265 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_gemv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
#include "ccv_nnc_mfa.hpp"
#include "ccv_nnc_mfa_hash.hpp"
#include <simd/simd.h>
using namespace ccv::nnc;

#include <string>

// MARK: - C

// Compile (or fetch from cache) the GEMV compute pipeline matching `params`.
// Must be called before ccv_nnc_mfa_encode_gemv with the same parameters.
void ccv_nnc_mfa_prepare_gemv(mfa::context* context, ccv_nnc_mfa_gemv_params_t params)
{
  const mfa::gemv::hash key(params);
  context->gemv_cache.prepare(context, key);
}

// Encode one GEMV dispatch into the command batch.
// `tensors` is NULL-terminated; layout: 0 = matrix, 1 = vector, 2 = output,
// optional 3 = fused bias. The pipeline for `params` must already be cached
// via ccv_nnc_mfa_prepare_gemv, otherwise this is a fatal precondition failure.
void ccv_nnc_mfa_encode_gemv(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gemv_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets)
{
  mfa::gemv::hash key(params);
  auto entry = context->gemv_cache.map.find(key);
  if (entry == context->gemv_cache.map.end()) {
    mfa::precondition_failure("gemv hash not cached.", __LINE__, __FILE__, __FUNCTION__);
  }
  auto* pipeline = entry->second;

  auto encoder = command_batch->startCommand();

  // Bind every buffer up to the NULL terminator at its matching argument index.
  int bound = 0;
  for (; tensors[bound] != nullptr; bound++) {
    encoder->setBuffer(tensors[bound], tensor_offsets[bound], NS::UInteger(bound));
  }
  // Exactly 3 buffers without fused bias, 4 with it.
  CCV_NNC_MFA_PRECONDITION(bound == 3 || bound == 4);

  encoder->setComputePipelineState(pipeline->gemv_pso.get());
  encoder->useResource(tensors[0], MTL::ResourceUsageRead);
  encoder->useResource(tensors[1], MTL::ResourceUsageRead);
  encoder->useResource(tensors[2], MTL::ResourceUsageWrite);
  if (bound == 4) {
    encoder->useResource(tensors[3], MTL::ResourceUsageRead);
  }

  const auto grid_size = pipeline->grid_size;
  CCV_NNC_MFA_PRECONDITION(grid_size.depth > 0);
  encoder->dispatchThreadgroups(grid_size, pipeline->group_size);
  command_batch->finishCommand(encoder);
}

// MARK: - C++

// Capture the pipeline-determining parameters into the cache key.
mfa::gemv::hash::hash(ccv_nnc_mfa_gemv_params_t params)
  : data_type(params.data_type),
    nrows(params.nrows),
    ncols(params.ncols),
    fused_bias(params.fused_bias)
{
}

// Keys are equal iff every field that influences pipeline compilation matches.
bool mfa::gemv::hash::operator==(const mfa::gemv::hash& hash) const {
  if (data_type != hash.data_type)
    return false;
  if (nrows != hash.nrows)
    return false;
  if (ncols != hash.ncols)
    return false;
  return fused_bias == hash.fused_bias;
}

// Debug formatter: renders the key in designated-initializer style,
// e.g. "mfa::gemv::hash { .data_type = ..., ... }".
std::ostream& operator<<(std::ostream& os, const mfa::gemv::hash& hash) {
  return os
    << "mfa::gemv::hash {"
    << " .data_type = " << hash.data_type << ','
    << " .nrows = " << hash.nrows << ','
    << " .ncols = " << hash.ncols << ','
    << " .fused_bias = " << bool(hash.fused_bias) << " "
    << "}";
}

// Hash for the pipeline cache: folds every field of the key into `seed` so
// distinct (data_type, nrows, ncols, fused_bias) tuples map to distinct
// pipelines. nrows/ncols are packed into one 64-bit word; fused_bias is
// widened to a uchar4 to reuse the 32-bit combiner.
std::size_t std::hash<mfa::gemv::hash>::operator()(const mfa::gemv::hash& hash) const noexcept {
std::size_t seed = 0;
using namespace mfa::hash;
combine_64(seed, hash.data_type);
combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.nrows, (unsigned int)hash.ncols }));
combine_32(seed, pack_32(simd::uchar4 { hash.fused_bias, 0, 0, 0 }));
return seed;
}

// Build the Metal compute pipeline for a GEMV of the given shape.
// Two shader variants are generated (with and without a fused bias add);
// each threadgroup is a single 32-lane simdgroup that reduces N = 8 output
// rows, with ncols/nrows baked into the source as compile-time constants.
mfa::gemv::pipeline::pipeline(mfa::context* context, mfa::gemv::hash hash) {
// Only float32 and float16 element types are supported by this kernel.
CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))

auto* pool = NS::AutoreleasePool::alloc()->init();

std::string shader;
if (hash.fused_bias) {
// Variant with bias: buffer(3) holds one bias element per output row.
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void gemv(
device const real *src0 [[buffer(0)]],
device const real *src1 [[buffer(1)]],
device real *dst [[buffer(2)]],
device const real *bias [[buffer(3)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t rb = tgpig * N;
device const real* y = (device const real*)src1;
if (ncols < 128) {
for (uint row = 0; row < N; ++row) {
uint r1 = rb + row;
if (r1 >= nrows) {
break;
}
device const real* x = (device const real*)src0 + r1 * ncols;
float sumf = 0;
for (uint i = tiisg; i < ncols; i += 32) {
sumf += (real)x[i] * (real)y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[r1] = bias[r1] + all_sum;
}
}
} else {
device const real4* y4 = (device const real4*)y;
for (uint row = 0; row < N; ++row) {
uint r1 = rb + row;
if (r1 >= nrows) {
break;
}
device const real* x = (device const real*)src0 + r1 * ncols;
device const real4* x4 = (device const real4*)x;
float sumf = 0;
for (uint i = tiisg; i < ncols / 4; i += 32) {
sumf += (real)x4[i][0] * y4[i][0];
sumf += (real)x4[i][1] * y4[i][1];
sumf += (real)x4[i][2] * y4[i][2];
sumf += (real)x4[i][3] * y4[i][3];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[r1] = bias[r1] + all_sum;
}
}
}
}
)";
} else {
// Variant without bias: identical reduction, writes the raw dot product.
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void gemv(
device const real *src0 [[buffer(0)]],
device const real *src1 [[buffer(1)]],
device real *dst [[buffer(2)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t rb = tgpig * N;
device const real* y = (device const real*)src1;
if (ncols < 128) {
for (uint row = 0; row < N; ++row) {
uint r1 = rb + row;
if (r1 >= nrows) {
break;
}
device const real* x = (device const real*)src0 + r1 * ncols;
float sumf = 0;
for (uint i = tiisg; i < ncols; i += 32) {
sumf += (real)x[i] * (real)y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[r1] = all_sum;
}
}
} else {
device const real4* y4 = (device const real4*)y;
for (uint row = 0; row < N; ++row) {
uint r1 = rb + row;
if (r1 >= nrows) {
break;
}
device const real* x = (device const real*)src0 + r1 * ncols;
device const real4* x4 = (device const real4*)x;
float sumf = 0;
for (uint i = tiisg; i < ncols / 4; i += 32) {
sumf += (real)x4[i][0] * y4[i][0];
sumf += (real)x4[i][1] * y4[i][1];
sumf += (real)x4[i][2] * y4[i][2];
sumf += (real)x4[i][3] * y4[i][3];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[r1] = all_sum;
}
}
}
}
)";
}

// Prologue prepended to the shader: the element type plus the shape
// constants, so the compiler can specialize the loops per pipeline.
std::string defines = "";
if (hash.data_type == MTL::DataTypeFloat) {
defines += std::string("typedef float real;");
defines += "\n";
defines += std::string("typedef float4 real4;");
defines += "\n";
} else {
defines += std::string("typedef half real;");
defines += "\n";
defines += std::string("typedef half4 real4;");
defines += "\n";
}

// N = rows handled per threadgroup; the hard-coded 8 in grid_size below
// must stay in sync with this constant.
defines += "constant uint N = 8;\n";
defines += "constant uint ncols = ";
defines += std::to_string(hash.ncols) + ";";
defines += "\n";
defines += "constant uint nrows = ";
defines += std::to_string(hash.nrows) + ";";
defines += "\n";
// One simdgroup (32 threads) per threadgroup; ceil(nrows / 8) threadgroups.
this->group_size = MTL::Size(32, 1, 1);
this->grid_size = MTL::Size((hash.nrows + 8 - 1) / 8, 1, 1);

auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
NS::SharedPtr<MTL::ComputePipelineState>* pso = &gemv_pso;

std::string source = defines;
// NOTE(review): this dumps only the defines — the shader body has not been
// appended yet. Confirm whether the full source was meant to be logged.
if (METAL_LOG_LEVEL(context) >= 4) {
std::cerr << source << std::endl;
}
source += shader;

NS::Error *error = nullptr;
auto swift_source = NS::String::string(source.c_str(),
NS::UTF8StringEncoding);
auto library = NS::TransferPtr(context->device->newLibrary(swift_source, nullptr, &error));
if (!library) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

auto swift_name = NS::String::string("gemv", NS::UTF8StringEncoding);
auto function = NS::TransferPtr(library->newFunction(swift_name, constants.get(), &error));
if (!function) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

*pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error));
if (!*pso) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

pool->drain();
}
Loading

0 comments on commit ca881af

Please sign in to comment.