Removed reference to previous GEMM kernel.
liuliu committed Aug 15, 2024
1 parent 785cf57 commit fc9e662
Showing 8 changed files with 18 additions and 286 deletions.
21 changes: 3 additions & 18 deletions lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
@@ -175,10 +175,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0));
// v1 only supports the same precision of accumulator as the tensor.
int is_different_accumulator_precision = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F) || ((cmd.info.blas.flags & CCV_NNC_GEMM_16F) && a_datatype == CCV_32F);
int is_upcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F);
const int is_mfa_supported =
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || (!(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM) && !is_different_accumulator_precision));
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM));

size_t a_data_size = 0;
if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
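For context, a minimal hedged sketch of the accumulator-precision decision after this change (CCV_NNC_GEMM_32F, CCV_16F, and a_datatype come from the hunk above; the standalone helper name is hypothetical):

    // Sketch: with the previous kernel, requesting an FP32 accumulator for FP16
    // tensors disabled the MFA path entirely. The replacement kernel accepts the
    // upcast, so this check now only decides whether register_float is set on the
    // dispatch parameters instead of gating is_mfa_supported.
    static inline int gemm_wants_fp32_accumulation(const uint32_t blas_flags, const int a_datatype)
    {
      return (blas_flags & CCV_NNC_GEMM_32F) && (a_datatype == CCV_16F);
    }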
@@ -364,11 +363,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = (is_transpose_a ? 1 : 0),
.B_trans = (is_transpose_w ? 1 : 0),
.D_trans = 0,
.alpha = (float)1.0,
.beta = (float)0.0,
.batched = is_batched,
.fused_activation_function = 0,
.fused_bias = (bias ? 1 : 0),
.register_float = (is_upcast ? 1 : 0),

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
@@ -795,10 +792,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 1,
.B_trans = (is_transpose_w ? 1 : 0),
.D_trans = 0,
.alpha = (float)1.0,
.beta = (float)0.0,
.batched = is_batched,
.fused_activation_function = 0,
.fused_bias = 0,

.batch_dims_a = { 0 },
@@ -834,10 +828,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = (is_transpose_w ? 0 : 1),
.D_trans = 0,
.alpha = (float)1.0,
.beta = (float)0.0,
.batched = is_batched,
.fused_activation_function = 0,
.fused_bias = 0,

.batch_dims_a = { 0 },
@@ -881,10 +872,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 1,
.B_trans = (is_transpose_a ? 1 : 0),
.D_trans = 0,
.alpha = (float)1.0,
.beta = (float)0.0,
.batched = is_batched,
.fused_activation_function = 0,
.fused_bias = 0,

.batch_dims_a = { 0 },
@@ -920,10 +908,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = (is_transpose_a ? 0 : 1),
.B_trans = 0,
.D_trans = 0,
.alpha = (float)1.0,
.beta = (float)0.0,
.batched = is_batched,
.fused_activation_function = 0,
.fused_bias = 0,

.batch_dims_a = { 0 },
6 changes: 0 additions & 6 deletions lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
@@ -256,10 +256,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = 1,
.D_trans = 0,
.alpha = (float)1.0,
.beta = (float)0.0,
.batched = is_batched,
.fused_activation_function = 0,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
@@ -275,10 +272,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = 0,
.D_trans = 1,
.alpha = (float)1.0,
.beta = (float)0.0,
.batched = is_batched,
.fused_activation_function = 0,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
@@ -316,10 +316,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
.A_trans = false,
.B_trans = true,
.D_trans = false,
.alpha = (float)1.0,
.beta = (float)0.0,
.batched = 0,
.fused_activation_function = 0,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
6 changes: 0 additions & 6 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
@@ -90,12 +90,6 @@ void mfa::cache<mfa::attention::hash, mfa::attention::pipeline>::prepare(mfa::co
_mfa_cache_prepare(&map, context, hash);
}

template <>
void mfa::cache<mfa::gemm::hash, mfa::gemm::pipeline>::prepare(mfa::context* context, mfa::gemm::hash hash)
{
_mfa_cache_prepare(&map, context, hash);
}

template <>
void mfa::cache<mfa::normalization::hash, mfa::normalization::pipeline>::prepare(mfa::context* context, mfa::normalization::hash hash)
{
3 changes: 1 addition & 2 deletions lib/nnc/mfa/ccv_nnc_mfa.hpp
@@ -4,11 +4,11 @@
#include "nnc/ccv_nnc.h"
#include "ccv_nnc_mfa_defines.hpp"
#include "ccv_nnc_mfa_attention.hpp"
#include "ccv_nnc_mfa_gemm.hpp"
#include "ccv_nnc_mfa_normalization.hpp"
#include "ccv_nnc_mfa_depalettize.hpp"
#include "ccv_nnc_mfa_adam.hpp"
#include "ccv_nnc_mfa_cmul.hpp"
#include "ccv_nnc_mfa_gemm.hpp"
#include "ccv_nnc_mfa_gemv.hpp"
#include "ccv_nnc_mfa_cast.hpp"
#include "ccv_nnc_mfa_add.hpp"
@@ -49,7 +49,6 @@ class context {
context(MTL::Device* device);

cache<attention::hash, attention::pipeline> attention_cache;
cache<gemm::hash, gemm::pipeline> gemm_cache;
cache<normalization::hash, normalization::pipeline> normalization_cache;
cache<depalettize::hash, depalettize::pipeline> depalettize_cache;
cache<adam::hash, adam::pipeline> adam_cache;
199 changes: 4 additions & 195 deletions lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp
@@ -13,7 +13,7 @@ using namespace ccv::nnc;

void ccv_nnc_mfa_prepare_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t params)
{
context->gemm_cache.prepare(context, mfa::gemm::hash(params));
// No-op.
}
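Since the GEMM pipeline is now built at encode time (a hedged reading of this commit; the prepare entry point appears to be kept only for API compatibility), callers can keep issuing the prepare/encode pair unchanged. A minimal usage sketch, where dispatch_gemm is a hypothetical wrapper and the two ccv_nnc_mfa_* signatures are taken from this file:

    // Sketch: prepare no longer does any work, so only the encode call matters.
    static void dispatch_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t params,
                              MTL::CommandBatch* command_batch, MTL::Buffer** tensors,
                              size_t* tensor_offsets)
    {
      ccv_nnc_mfa_prepare_gemm(context, params);  // no-op after this commit
      ccv_nnc_mfa_encode_gemm(context, params, command_batch, tensors, tensor_offsets);
    }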

void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t params, MTL::CommandBatch* command_batch, MTL::Buffer** tensors, size_t* tensor_offsets)
@@ -55,6 +55,8 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa
break;
}
gemmDesc.transposeState = simd::uchar3 { params.A_trans, params.B_trans, params.D_trans };
gemmDesc.registerPrecisionC = (params.register_float) ? std::optional(GEMMOperandPrecision::FP32) : std::nullopt;
gemmDesc.leadingDimensions = std::nullopt;
gemmDesc.loadPreviousC = false;
gemmDesc.useBias = params.fused_bias;
if (params.batched) {
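A hedged sketch of what the two new descriptor fields do (GEMMOperandPrecision, registerPrecisionC, and leadingDimensions are from the hunk above; the explanation of their defaults is an assumption):

    // Sketch: registerPrecisionC is an optional override. Left empty, the kernel is
    // assumed to accumulate in the operands' native precision; set to FP32, it asks
    // for 32-bit register accumulation even when A and B are FP16. Leaving
    // leadingDimensions empty presumably lets the kernel derive row strides from
    // M/N/K and the transpose state.
    gemmDesc.registerPrecisionC = params.register_float
        ? std::optional<GEMMOperandPrecision>(GEMMOperandPrecision::FP32)
        : std::nullopt;
    gemmDesc.leadingDimensions = std::nullopt;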
@@ -71,7 +73,7 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa
continue;
} else if (operand == 3) {
// Skip the D operand if unavailable.
if (!(params.fused_activation_function || params.fused_bias)) {
if (!params.fused_bias) {
continue;
}
batch_dims = params.batch_dims_d;
@@ -161,196 +163,3 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa
command_batch->finishCommand(encoder);
}

// MARK: - C++

mfa::gemm::hash::hash(ccv_nnc_mfa_gemm_params_t params) {
data_type = params.data_type;
M = params.M;
N = params.N;
K = params.K;
A_trans = params.A_trans;
B_trans = params.B_trans;
D_trans = params.D_trans;
alpha = params.alpha;
beta = params.beta;
batched = params.batched;
fused_activation_function = params.fused_activation_function;
fused_bias = params.fused_bias;
}

bool mfa::gemm::hash::operator==(const mfa::gemm::hash& hash) const {
return
(data_type == hash.data_type) &&
(M == hash.M) &&
(N == hash.N) &&
(K == hash.K) &&
(A_trans == hash.A_trans) &&
(B_trans == hash.B_trans) &&
(D_trans == hash.D_trans) &&
(alpha == hash.alpha) &&
(beta == hash.beta) &&
(batched == hash.batched) &&
(fused_activation_function == hash.fused_activation_function) &&
(fused_bias == hash.fused_bias);
}

std::ostream& operator<<(std::ostream& os, const mfa::gemm::hash& hash) {
os << "mfa::gemm::hash {";
os << " .data_type = " << hash.data_type << ',';
os << " .M = " << hash.M << ',';
os << " .N = " << hash.N << ',';
os << " .K = " << hash.K << ',';
os << " .A_trans = " << bool(hash.A_trans) << ',';
os << " .B_trans = " << bool(hash.B_trans) << ',';
os << " .D_trans = " << bool(hash.D_trans) << ',';
os << " .alpha = " << double(hash.alpha) << ',';
os << " .beta = " << double(hash.beta) << ',';
os << " .batched = " << bool(hash.batched) << ',';
os << " .fused_activation_function = " << bool(hash.fused_activation_function) << ',';
os << " .fused_bias = " << bool(hash.fused_bias) << " ";
os << "}";
return os;
}

std::size_t std::hash<mfa::gemm::hash>::operator()(const mfa::gemm::hash& hash) const noexcept {
std::size_t seed = 0;
using namespace mfa::hash;
combine_64(seed, hash.data_type);
combine_64(seed, pack_64(simd::uint2 { hash.M, hash.N }));
combine_64(seed, pack_64(simd::uint2 { hash.K, pack_32(simd::uchar4 { hash.A_trans, hash.B_trans, hash.D_trans, 0 }) }));
combine_64(seed, pack_64(simd::uint2 { *reinterpret_cast<const uint32_t*>(&hash.alpha), *reinterpret_cast<const uint32_t*>(&hash.beta) }));
combine_32(seed, pack_32(simd::uchar4 { hash.batched, hash.fused_activation_function, hash.fused_bias, 0 }));
return seed;
}

mfa::gemm::pipeline::pipeline(mfa::context* context, mfa::gemm::hash hash) {
CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))
CCV_NNC_MFA_PRECONDITION(hash.alpha == 1.0)
CCV_NNC_MFA_PRECONDITION(hash.beta == 0.0)
CCV_NNC_MFA_PRECONDITION(hash.fused_activation_function == false)

auto* pool = NS::AutoreleasePool::alloc()->init();

auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
constants->setConstantValue(&hash.M, MTL::DataTypeUInt, NS::UInteger(0));
constants->setConstantValue(&hash.N, MTL::DataTypeUInt, 1);
constants->setConstantValue(&hash.K, MTL::DataTypeUInt, 2);
constants->setConstantValue(&hash.A_trans, MTL::DataTypeBool, 10);
constants->setConstantValue(&hash.B_trans, MTL::DataTypeBool, 11);
constants->setConstantValue(&hash.D_trans, MTL::DataTypeBool, 13);
constants->setConstantValue(&hash.alpha, MTL::DataTypeFloat, 20);
constants->setConstantValue(&hash.beta, MTL::DataTypeFloat, 21);
constants->setConstantValue(&hash.batched, MTL::DataTypeBool, 100);
constants->setConstantValue(&hash.fused_activation_function, MTL::DataTypeBool, 101);
constants->setConstantValue(&hash.fused_bias, MTL::DataTypeBool, 50001);
simd::ulong4 garbage(0);
constants->setConstantValue(&garbage, MTL::DataTypeBool, 102);
constants->setConstantValue(&garbage, MTL::DataTypeBool, 103);
constants->setConstantValue(&garbage, MTL::DataTypeBool, 113);
constants->setConstantValue(&garbage, MTL::DataTypeBool, 50000);

// Eventually, this may incorporate the batch size.
// BxMxN > 1,000,000 -> 48x48, only if M >= 88 and N >= 88
// BxMxN > 4,000,000 -> 64x64, only if M >= 120 and N >= 120
uint64_t C_elements = uint64_t(hash.M) * uint64_t(hash.N);
if (hash.batched) {
C_elements *= 2;
}
int is_half = (hash.data_type == MTL::DataTypeHalf); // SD v1 attention
int is_float = (hash.data_type == MTL::DataTypeFloat); // SD v2 attention

uint16_t M_group = 32;
uint16_t N_group = 32;
uint16_t K_simd = 32;
if (C_elements > 1000 * 1000) {
M_group = 48;
N_group = 48;
}

// If K_simd is perfectly equal to matrix K, the compiler can elide a large
// amount of logic in the kernel.
if (hash.K >= 33 && hash.K <= 40) {
K_simd = 40; // 1 * 40
} else if (is_half && hash.K >= 73 && hash.K <= 80) {
K_simd = 40; // 2 * 40
} else if (C_elements > 1000 * 1000) {
if (hash.K <= 24) {
K_simd = 24; // 1 * 24
} else if (hash.K <= 32) {
K_simd = 32; // 1 * 32
} else if (hash.K <= 48) {
K_simd = 24;
} else if (hash.K <= 64) {
K_simd = 32;
} else if (is_float) {
K_simd = 24;
}
}

uint16_t M_splits = 2;
uint16_t N_splits = 2;
uint16_t M_simd = M_group / M_splits;
uint16_t N_simd = N_group / N_splits;

constants->setConstantValue(&M_simd, MTL::DataTypeUShort, 200);
constants->setConstantValue(&N_simd, MTL::DataTypeUShort, 201);
constants->setConstantValue(&K_simd, MTL::DataTypeUShort, 202);
constants->setConstantValue(&M_splits, MTL::DataTypeUShort, 210);
constants->setConstantValue(&N_splits, MTL::DataTypeUShort, 211);

std::string cpp_name;
uint16_t data_type_size = UINT16_MAX;
switch (hash.data_type) {
case MTL::DataTypeHalf: {
cpp_name = "hgemm";
data_type_size = 2;
break;
}
case MTL::DataTypeFloat: {
cpp_name = "sgemm";
data_type_size = 4;
break;
}
default: {
CCV_NNC_MFA_PRECONDITION(false)
break;
}
}
auto* swift_name = NS::String::string(cpp_name.c_str(), NS::UTF8StringEncoding);

uint16_t A_block_bytes = M_group * K_simd * data_type_size;
uint16_t B_block_bytes = K_simd * N_group * data_type_size;
uint16_t C_block_bytes = M_group * N_group * data_type_size;
threadgroup_memory_length = A_block_bytes + B_block_bytes;

if ((hash.M % 8 > 0) && (hash.N % 8 > 0)) {
if (C_block_bytes > threadgroup_memory_length) {
threadgroup_memory_length = C_block_bytes;
}
}
if (hash.fused_bias) {
uint16_t D_block_bytes = (hash.D_trans ? M_group : N_group) * data_type_size;
if (D_block_bytes > threadgroup_memory_length) {
threadgroup_memory_length = D_block_bytes;
}
}

std::function<size_t(size_t, uint16_t)> ceil_divide = [](size_t original, uint16_t granularity) {
return (original + size_t(granularity) - 1) / size_t(granularity);
};
grid_size = MTL::Size(ceil_divide(hash.N, N_group), ceil_divide(hash.M, M_group), 1);
group_size = MTL::Size(32 * M_splits * N_splits, 1, 1);

NS::Error* error = nullptr;
auto function = NS::TransferPtr(context->library->newFunction(swift_name, constants.get(), &error));
if (!function) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error));
if (!pso) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

pool->drain();
}