Skip to content

Commit

Permalink
Move cast code to CastKernel
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 11, 2024
1 parent 0539805 commit c338ac3
Show file tree
Hide file tree
Showing 12 changed files with 329 additions and 290 deletions.
6 changes: 0 additions & 6 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,6 @@ void mfa::cache<mfa::gemv::hash, mfa::gemv::pipeline>::prepare(mfa::context* con
_mfa_cache_prepare(&map, context, hash);
}

// Explicit specialization that prepares the cached cast pipeline for the
// given hash. Delegates to the shared _mfa_cache_prepare helper, following
// the same pattern as the other cache<>::prepare specializations in this file.
template <>
void mfa::cache<mfa::cast::hash, mfa::cast::pipeline>::prepare(mfa::context* context, mfa::cast::hash hash)
{
_mfa_cache_prepare(&map, context, hash);
}

template <>
void mfa::cache<mfa::add::hash, mfa::add::pipeline>::prepare(mfa::context* context, mfa::add::hash hash)
{
Expand Down
1 change: 0 additions & 1 deletion lib/nnc/mfa/ccv_nnc_mfa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class context {
cache<depalettize::hash, depalettize::pipeline> depalettize_cache;
cache<adam::hash, adam::pipeline> adam_cache;
cache<gemv::hash, gemv::pipeline> gemv_cache;
cache<cast::hash, cast::pipeline> cast_cache;
cache<add::hash, add::pipeline> add_cache;

ShaderCache v2_cache;
Expand Down
198 changes: 32 additions & 166 deletions lib/nnc/mfa/ccv_nnc_mfa_cast.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "ccv_nnc_mfa.hpp"
#include "ccv_nnc_mfa_hash.hpp"
#include "v2/CastDescriptor.hpp"
#include "v2/CastKernel.hpp"
#include <simd/simd.h>
using namespace ccv::nnc;

Expand All @@ -9,18 +11,11 @@ using namespace ccv::nnc;

void ccv_nnc_mfa_prepare_cast(mfa::context* context, ccv_nnc_mfa_cast_params_t params)
{
context->cast_cache.prepare(context, mfa::cast::hash(params));
// Do nothing now.
}

void ccv_nnc_mfa_encode_cast(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_cast_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets)
{
mfa::cast::hash hash(params);
auto iterator = context->cast_cache.map.find(hash);
if (iterator == context->cast_cache.map.end()) {
mfa::precondition_failure("cast hash not cached.", __LINE__, __FILE__, __FUNCTION__);
}

auto* pipeline = iterator->second;
auto encoder = command_batch->startCommand();

int num_tensors = 0;
Expand All @@ -29,176 +24,47 @@ void ccv_nnc_mfa_encode_cast(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_cast_pa
num_tensors += 1;
}
CCV_NNC_MFA_PRECONDITION(num_tensors == 2);

encoder->setComputePipelineState(pipeline->cast_pso.get());
encoder->useResource(tensors[0], MTL::ResourceUsageRead);
encoder->useResource(tensors[1], MTL::ResourceUsageWrite);

auto grid_size = pipeline->grid_size;
CCV_NNC_MFA_PRECONDITION(grid_size.depth > 0);
encoder->dispatchThreadgroups(grid_size, pipeline->group_size);
command_batch->finishCommand(encoder);
}

// MARK: - C++

// Builds a cast cache key from the C-side parameter struct, copying the
// source type, destination type, and element count.
mfa::cast::hash::hash(ccv_nnc_mfa_cast_params_t params)
  : original_data_type(params.original_data_type),
    data_type(params.data_type),
    length(params.length) {}

// Two cast keys are equal only when every identifying field matches.
bool mfa::cast::hash::operator==(const mfa::cast::hash& rhs) const {
  if (original_data_type != rhs.original_data_type) {
    return false;
  }
  if (data_type != rhs.data_type) {
    return false;
  }
  return length == rhs.length;
}

// Debug formatter: prints the cast key in brace-initializer style.
std::ostream& operator<<(std::ostream& os, const mfa::cast::hash& hash) {
  return os << "mfa::cast::hash {"
            << " .original_data_type = " << hash.original_data_type << ','
            << " .data_type = " << hash.data_type << ','
            << " .length = " << hash.length << " "
            << "}";
}

// Hash functor so mfa::cast::hash can key hash-based containers (the
// context's cast cache map). Folds each identifying field into the seed
// via the shared mfa::hash combiners; the combination order here is part
// of the hash value and must stay fixed for cache lookups to be stable.
std::size_t std::hash<mfa::cast::hash>::operator()(const mfa::cast::hash& hash) const noexcept {
std::size_t seed = 0;
using namespace mfa::hash;
// 64-bit type fields first, then the 32-bit element count.
combine_64(seed, hash.original_data_type);
combine_64(seed, hash.data_type);
combine_32(seed, hash.length);
return seed;
}

mfa::cast::pipeline::pipeline(mfa::context* context, mfa::cast::hash hash) {
CCV_NNC_MFA_PRECONDITION((hash.original_data_type == MTL::DataTypeFloat) || (hash.original_data_type == MTL::DataTypeHalf))
CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))

auto* pool = NS::AutoreleasePool::alloc()->init();

std::string shader;
// In this case, we can ignore the boundary check.
if (hash.length % (4 * 256) == 0) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void cast(
device original_real4 *src [[buffer(0)]],
device real4 *destination [[buffer(1)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
destination[idx] = (real4)(src[idx]);
}
)";
} else if (hash.length % 4 == 0) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
CastDescriptor descriptor;
descriptor.fromMemoryPrecision = (params.original_data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16;
descriptor.memoryPrecision = (params.data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16;
descriptor.length = params.length;

kernel void cast(
device original_real4 *src [[buffer(0)]],
device real4 *destination [[buffer(1)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
destination[idx] = (real4)(src[idx]);
}
)";
if (params.length % (4 * 256) == 0) {
descriptor.value = 0;
} else if (params.length % 4 == 0) {
descriptor.value = 1;
} else {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void cast(
device original_real *src [[buffer(0)]],
device real *destination [[buffer(1)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
destination[idx] = (real)(src[idx]);
}
)";
descriptor.value = 2;
}

std::string defines = "";
if (hash.data_type == MTL::DataTypeFloat) {
defines += std::string("typedef float real;");
defines += "\n";
defines += std::string("typedef float4 real4;");
defines += "\n";
} else {
defines += std::string("typedef half real;");
defines += "\n";
defines += std::string("typedef half4 real4;");
defines += "\n";
}
auto pool = NS::AutoreleasePool::alloc()->init();
auto &shaderCache = context->v2_cache;
DeviceProperties dprops = DeviceProperties();
auto pipelineValue = shaderCache.findKernel<CastKernel, CastDescriptor, CastKernelDescriptor>(descriptor, context->device.get(), dprops);
pool->drain();
auto kernel = pipelineValue->kernel;
auto pipeline = pipelineValue->pipeline;

if (hash.original_data_type == MTL::DataTypeFloat) {
defines += std::string("typedef float original_real;");
defines += "\n";
defines += std::string("typedef float4 original_real4;");
defines += "\n";
encoder->setComputePipelineState(pipeline.get());

if (tensors[0] == tensors[1]) {
encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
} else {
defines += std::string("typedef half original_real;");
defines += "\n";
defines += std::string("typedef half4 original_real4;");
defines += "\n";
encoder->useResource(tensors[0], MTL::ResourceUsageRead);
encoder->useResource(tensors[1], MTL::ResourceUsageWrite);
}

unsigned int count;
if (hash.length % 4 == 0) {
count = hash.length / 4;
if (params.length % 4 == 0) {
count = params.length / 4;
} else {
count = hash.length;
}
if (hash.length % (4 * 256) != 0) {
defines += "constant uint count = ";
defines += std::to_string(count) + ";";
defines += "\n";
count = params.length;
}
this->group_size = MTL::Size(256, 1, 1);
const int num_blocks = (count + 255) / 256;
this->grid_size = MTL::Size(num_blocks, 1, 1);
MTL::Size gridSize = MTL::Size(num_blocks, 1, 1);
CCV_NNC_MFA_PRECONDITION(gridSize.depth > 0);
encoder->dispatchThreadgroups(gridSize, kernel->threadgroupSize);

auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
NS::SharedPtr<MTL::ComputePipelineState>* pso = &cast_pso;

std::string source = defines;
if (METAL_LOG_LEVEL(context) >= 4) {
std::cerr << source << std::endl;
}
source += shader;

NS::Error *error = nullptr;
auto swift_source = NS::String::string(source.c_str(),
NS::UTF8StringEncoding);
auto library = NS::TransferPtr(context->device->newLibrary(swift_source, nullptr, &error));
if (!library) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

auto swift_name = NS::String::string("cast", NS::UTF8StringEncoding);
auto function = NS::TransferPtr(library->newFunction(swift_name, constants.get(), &error));
if (!function) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

*pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error));
if (!*pso) {
CCV_NNC_MFA_CHECK_ERROR(error)
}

pool->drain();
command_batch->finishCommand(encoder);
}
43 changes: 0 additions & 43 deletions lib/nnc/mfa/ccv_nnc_mfa_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,6 @@ typedef struct {
} ccv_nnc_mfa_cast_params_t;

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
#include <simd/simd.h>

namespace ccv {
namespace nnc {
namespace mfa {
namespace cast {

// Cache key identifying a unique cast pipeline configuration: requests
// with the same source type, destination type, and element count share
// one compiled pipeline.
class hash {
public:
uint64_t original_data_type; // MTL::DataType of the source buffer
uint64_t data_type; // MTL::DataType of the destination buffer
uint32_t length; // total number of elements to convert

// Builds the key from the C-side parameter struct.
hash(ccv_nnc_mfa_cast_params_t);

// Keys compare equal iff all three fields match (defined in the .cpp).
bool operator==(const hash& rhs) const;
};

// Compiled compute pipeline plus dispatch geometry for one cast
// configuration; cached by the context keyed on cast::hash.
class pipeline {
public:
// Pipeline state object for the "cast" compute kernel.
NS::SharedPtr<MTL::ComputePipelineState> cast_pso;

// Threadgroup grid and per-threadgroup size used at encode time.
MTL::Size grid_size;
MTL::Size group_size;

// Compiles the Metal shader selected by `hash` (defined in the .cpp).
pipeline(context* context, hash hash);
};

} // namespace cast
} // namespace mfa
} // namespace nnc
} // namespace ccv

std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::cast::hash& hash);

// std::hash specialization so cast::hash can key hash-based containers
// such as the context's cast cache map; implementation is in the .cpp.
template<>
struct std::hash<ccv::nnc::mfa::cast::hash>
{
std::size_t operator()(const ccv::nnc::mfa::cast::hash& hash) const noexcept;
};

extern "C" {
#endif // __cplusplus

Expand Down
35 changes: 0 additions & 35 deletions lib/nnc/mfa/ccv_nnc_mfa_cmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,6 @@ typedef struct {
} ccv_nnc_mfa_cmul_params_t;

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
#include <simd/simd.h>

namespace ccv {
namespace nnc {
namespace mfa {
namespace cmul {

// Cache key identifying a unique cmul pipeline configuration.
class hash {
public:
uint64_t data_type; // MTL::DataType of the tensors
// NOTE(review): astride/bstride/cstride are presumably per-axis strides of
// the three operands and dim the iteration bounds — confirm against the
// cmul kernel implementation.
uint32_t astride[3];
uint32_t bstride[3];
uint32_t cstride[3];
uint32_t dim[4];

// Builds the key from the C-side parameter struct.
hash(ccv_nnc_mfa_cmul_params_t);
};

// Compiled compute pipeline plus dispatch geometry for one cmul
// configuration.
class pipeline {
public:
// Pipeline state object for the cmul compute kernel.
NS::SharedPtr<MTL::ComputePipelineState> cmul_pso;

// Threadgroup grid and per-threadgroup size used at encode time.
MTL::Size grid_size;
MTL::Size group_size;

pipeline(context* context, hash hash);
};

} // namespace cmul
} // namespace mfa
} // namespace nnc
} // namespace ccv

extern "C" {
#endif // __cplusplus

Expand Down
35 changes: 0 additions & 35 deletions lib/nnc/mfa/ccv_nnc_mfa_gelu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,6 @@ typedef struct {
} ccv_nnc_mfa_gelu_params_t;

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
#include <simd/simd.h>

namespace ccv {
namespace nnc {
namespace mfa {
namespace gelu {

// Cache key identifying a unique gelu pipeline configuration.
class hash {
public:
uint64_t data_type; // MTL::DataType of the tensors
// NOTE(review): astride/bstride/cstride are presumably per-axis strides of
// the operands and dim the iteration bounds — confirm against the gelu
// kernel implementation.
uint32_t astride[3];
uint32_t bstride[3];
uint32_t cstride[3];
uint32_t dim[4];

// Builds the key from the C-side parameter struct.
hash(ccv_nnc_mfa_gelu_params_t);
};

// Compiled compute pipeline plus dispatch geometry for one gelu
// configuration.
class pipeline {
public:
// Pipeline state object for the gelu compute kernel.
NS::SharedPtr<MTL::ComputePipelineState> gelu_pso;

// Threadgroup grid and per-threadgroup size used at encode time.
MTL::Size grid_size;
MTL::Size group_size;

pipeline(context* context, hash hash);
};

} // namespace gelu
} // namespace mfa
} // namespace nnc
} // namespace ccv

extern "C" {
#endif // __cplusplus

Expand Down
3 changes: 0 additions & 3 deletions lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ typedef struct {
} ccv_nnc_mfa_gemm_params_t;

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"

extern "C" {
#endif // __cplusplus

Expand Down
Loading

0 comments on commit c338ac3

Please sign in to comment.