Skip to content

Commit

Permalink
Move add to v2 AddKernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 11, 2024
1 parent c338ac3 commit bb0cfeb
Show file tree
Hide file tree
Showing 9 changed files with 303 additions and 207 deletions.
6 changes: 0 additions & 6 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,6 @@ void mfa::cache<mfa::gemv::hash, mfa::gemv::pipeline>::prepare(mfa::context* con
_mfa_cache_prepare(&map, context, hash);
}

template <>
void mfa::cache<mfa::add::hash, mfa::add::pipeline>::prepare(mfa::context* context, mfa::add::hash hash)
{
_mfa_cache_prepare(&map, context, hash);
}

mfa::context::context(MTL::Device* device)
{
auto* pool = NS::AutoreleasePool::alloc()->init();
Expand Down
1 change: 0 additions & 1 deletion lib/nnc/mfa/ccv_nnc_mfa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class context {
cache<depalettize::hash, depalettize::pipeline> depalettize_cache;
cache<adam::hash, adam::pipeline> adam_cache;
cache<gemv::hash, gemv::pipeline> gemv_cache;
cache<add::hash, add::pipeline> add_cache;

ShaderCache v2_cache;

Expand Down
189 changes: 32 additions & 157 deletions lib/nnc/mfa/ccv_nnc_mfa_add.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "ccv_nnc_mfa.hpp"
#include "ccv_nnc_mfa_hash.hpp"
#include "v2/AddDescriptor.hpp"
#include "v2/AddKernel.hpp"
#include <simd/simd.h>
using namespace ccv::nnc;

Expand All @@ -9,18 +11,11 @@ using namespace ccv::nnc;

void ccv_nnc_mfa_prepare_add(mfa::context* context, ccv_nnc_mfa_add_params_t params)
{
context->add_cache.prepare(context, mfa::add::hash(params));
// Do nothing now.
}

void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets)
{
mfa::add::hash hash(params);
auto iterator = context->add_cache.map.find(hash);
if (iterator == context->add_cache.map.end()) {
mfa::precondition_failure("add hash not cached.", __LINE__, __FILE__, __FUNCTION__);
}

auto* pipeline = iterator->second;
auto encoder = command_batch->startCommand();

int num_tensors = 0;
Expand All @@ -29,8 +24,29 @@ void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_para
num_tensors += 1;
}
CCV_NNC_MFA_PRECONDITION(num_tensors == 3);

AddDescriptor descriptor;
descriptor.memoryPrecision = (params.data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16;
descriptor.length = params.length;

if (params.length % (4 * 256) == 0) {
descriptor.value = 0;
} else if (params.length % 4 == 0) {
descriptor.value = 1;
} else {
descriptor.value = 2;
}

auto pool = NS::AutoreleasePool::alloc()->init();
auto &shaderCache = context->v2_cache;
DeviceProperties dprops = DeviceProperties();
auto pipelineValue = shaderCache.findKernel<AddKernel, AddDescriptor, AddKernelDescriptor>(descriptor, context->device.get(), dprops);
pool->drain();
auto kernel = pipelineValue->kernel;
auto pipeline = pipelineValue->pipeline;

encoder->setComputePipelineState(pipeline.get());

encoder->setComputePipelineState(pipeline->add_pso.get());
if (tensors[0] == tensors[2]) {
encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
encoder->useResource(tensors[1], MTL::ResourceUsageRead);
Expand All @@ -43,156 +59,15 @@ void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_para
encoder->useResource(tensors[2], MTL::ResourceUsageWrite);
}

auto grid_size = pipeline->grid_size;
CCV_NNC_MFA_PRECONDITION(grid_size.depth > 0);
encoder->dispatchThreadgroups(grid_size, pipeline->group_size);
command_batch->finishCommand(encoder);
}

// MARK: - C++

mfa::add::hash::hash(ccv_nnc_mfa_add_params_t params) {
data_type = params.data_type;
length = params.length;
}

bool mfa::add::hash::operator==(const mfa::add::hash& hash) const {
return (data_type == hash.data_type) && (length == hash.length);
}

std::ostream& operator<<(std::ostream& os, const mfa::add::hash& hash) {
os << "mfa::add::hash {";
os << " .data_type = " << hash.data_type << ',';
os << " .length = " << hash.length << " ";
os << "}";
return os;
}

std::size_t std::hash<mfa::add::hash>::operator()(const mfa::add::hash& hash) const noexcept {
std::size_t seed = 0;
using namespace mfa::hash;
combine_64(seed, hash.data_type);
combine_32(seed, hash.length);
return seed;
}

mfa::add::pipeline::pipeline(mfa::context* context, mfa::add::hash hash) {
CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))

auto* pool = NS::AutoreleasePool::alloc()->init();

std::string shader;
// In this case, we can igore the boundary check.
if (hash.length % (4 * 256) == 0) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device const real4 *src0 [[buffer(0)]],
device const real4 *src1 [[buffer(1)]],
device real4 *dst [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
dst[idx] = src0[idx] + src1[idx];
}
)";
} else if (hash.length % 4 == 0) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device const real4 *src0 [[buffer(0)]],
device const real4 *src1 [[buffer(1)]],
device real4 *dst [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
dst[idx] = src0[idx] + src1[idx];
}
)";
} else {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device const real *src0 [[buffer(0)]],
device const real *src1 [[buffer(1)]],
device real *dst [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
dst[idx] = src0[idx] + src1[idx];
}
)";
}

std::string defines = "";
if (hash.data_type == MTL::DataTypeFloat) {
defines += std::string("typedef float4 real4;");
defines += "\n";
defines += std::string("typedef float real;");
defines += "\n";
} else {
defines += std::string("typedef half4 real4;");
defines += "\n";
defines += std::string("typedef half real;");
defines += "\n";
}

unsigned int count;
if (hash.length % 4 == 0) {
count = hash.length / 4;
if (params.length % 4 == 0) {
count = params.length / 4;
} else {
count = hash.length;
}
// Only boundary check needs this const in the shader.
if (hash.length % (4 * 256) != 0) {
defines += "constant uint count = ";
defines += std::to_string(count) + ";";
defines += "\n";
count = params.length;
}
this->group_size = MTL::Size(256, 1, 1);
const int num_blocks = (count + 255) / 256;
this->grid_size = MTL::Size(num_blocks, 1, 1);

auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
NS::SharedPtr<MTL::ComputePipelineState>* pso = &add_pso;

std::string source = defines;
if (METAL_LOG_LEVEL(context) >= 4) {
std::cerr << source << std::endl;
}
source += shader;

NS::Error *error = nullptr;
auto swift_source = NS::String::string(source.c_str(),
NS::UTF8StringEncoding);
auto library = NS::TransferPtr(context->device->newLibrary(swift_source, nullptr, &error));
if (!library) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

auto swift_name = NS::String::string("add", NS::UTF8StringEncoding);
auto function = NS::TransferPtr(library->newFunction(swift_name, constants.get(), &error));
if (!function) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

*pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error));
if (!*pso) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

pool->drain();
MTL::Size gridSize = MTL::Size(num_blocks, 1, 1);
CCV_NNC_MFA_PRECONDITION(gridSize.depth > 0);
encoder->dispatchThreadgroups(gridSize, kernel->threadgroupSize);
command_batch->finishCommand(encoder);
}
42 changes: 0 additions & 42 deletions lib/nnc/mfa/ccv_nnc_mfa_add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,6 @@ typedef struct {
} ccv_nnc_mfa_add_params_t;

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
#include <simd/simd.h>

namespace ccv {
namespace nnc {
namespace mfa {
namespace add {

class hash {
public:
uint64_t data_type;
uint32_t length;

hash(ccv_nnc_mfa_add_params_t);

bool operator==(const hash& rhs) const;
};

class pipeline {
public:
NS::SharedPtr<MTL::ComputePipelineState> add_pso;

MTL::Size grid_size;
MTL::Size group_size;

pipeline(context* context, hash hash);
};

} // namespace add
} // namespace mfa
} // namespace nnc
} // namespace ccv

std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::add::hash& hash);

template<>
struct std::hash<ccv::nnc::mfa::add::hash>
{
std::size_t operator()(const ccv::nnc::mfa::add::hash& hash) const noexcept;
};

extern "C" {
#endif // __cplusplus

Expand Down
2 changes: 1 addition & 1 deletion lib/nnc/mfa/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ include ../../config.mk

CFLAGS := -std=c++17 -O3 -Wall -I"../../" $(CFLAGS)

SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gelu.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp v2/AttentionDescriptor.cpp v2/AttentionKernelDescriptor.cpp v2/AttentionKernel.cpp v2/CMulDescriptor.cpp v2/CMulKernel.cpp v2/GeluDescriptor.cpp v2/GeluKernel.cpp v2/CastDescriptor.cpp v2/CastKernel.cpp
SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gelu.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp v2/AttentionDescriptor.cpp v2/AttentionKernelDescriptor.cpp v2/AttentionKernel.cpp v2/CMulDescriptor.cpp v2/CMulKernel.cpp v2/GeluDescriptor.cpp v2/GeluKernel.cpp v2/CastDescriptor.cpp v2/CastKernel.cpp v2/AddDescriptor.cpp v2/AddKernel.cpp

SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(SRCS)))

Expand Down
76 changes: 76 additions & 0 deletions lib/nnc/mfa/v2/AddDescriptor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include "AddDescriptor.hpp"
#include "AddKernel.hpp"
#include "../ccv_nnc_mfa_hash.hpp"
#include "../ccv_nnc_mfa_error.hpp"

bool AddDescriptor::operator==(const AddDescriptor& rhs) const {
return
memoryPrecision == rhs.memoryPrecision &&
value == rhs.value &&
length == rhs.length;
}

std::size_t std::hash<AddDescriptor>::operator()(const AddDescriptor& hash) const noexcept {
using namespace ccv::nnc::mfa::hash;
std::size_t seed = 0;
combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.memoryPrecision.value, (unsigned int)hash.value }));
combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.length, 0 }));
return seed;
}

std::pair<AddKernelDescriptor, PipelineValue<AddKernel> *> AddDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map<AddKernelDescriptor, std::unique_ptr<AddKernel>> *const libraryCache) const noexcept {
// The caller is not responsible for calling 'delete' on this pointer. The
// reference is saved in the 'libraryCache'. It will be deallocated whenever
// the shader cache itself is cleaned up.
auto createKernel =
[=](AddKernelDescriptor descriptor) -> AddKernel* {
auto iterator = libraryCache->find(descriptor);
if (iterator != libraryCache->end()) {
return iterator->second.get();
} else {
AddKernel* kernel = new AddKernel(descriptor, device);
(*libraryCache)[descriptor] = std::unique_ptr<AddKernel>(kernel);
return kernel;
}
};

AddKernelDescriptor kernelDesc;
kernelDesc.value = value;
kernelDesc.memoryPrecision = memoryPrecision;

// WARNING: The owner must explicitly retain the compute pipeline.
auto createPipeline =
[=](MTL::Library* library) -> MTL::ComputePipelineState* {
// Set the function constants.
auto constants = NS::TransferPtr
(MTL::FunctionConstantValues::alloc()->init());
uint32_t count;
if (value == 0) {
} else if (value == 1) {
count = length / 4;
constants->setConstantValue(&count, MTL::DataTypeUInt, NS::UInteger(0));
} else {
count = length;
constants->setConstantValue(&count, MTL::DataTypeUInt, NS::UInteger(0));
}

NS::String* swiftName = NS::String::string("add", NS::UTF8StringEncoding);
NS::Error* error = nil;

auto function = NS::TransferPtr
(library->newFunction(swiftName, constants.get(), &error));
CCV_NNC_MFA_CHECK_ERROR(error);

auto pipeline = device->newComputePipelineState(function.get(), &error);
CCV_NNC_MFA_CHECK_ERROR(error);
return pipeline;
};
AddKernel* kernel = createKernel(kernelDesc);
auto pipeline = NS::TransferPtr(createPipeline(kernel->library.get()));

// Force the user to retrieve the return value from the cache. We ensure
// the cache takes ownership, and the pointer doesn't become a zombie
// object.
PipelineValue<AddKernel>* output = new PipelineValue<AddKernel> { kernel, pipeline };
return std::make_pair(kernelDesc, output);
}
Loading

0 comments on commit bb0cfeb

Please sign in to comment.