Skip to content

Commit

Permalink
Add specialized implementation for depalettize q6p / q8p in GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Sep 8, 2023
1 parent 4248001 commit 5a05d9f
Show file tree
Hide file tree
Showing 10 changed files with 369 additions and 24 deletions.
1 change: 1 addition & 0 deletions lib/nnc/ccv_nnc_palettize.c
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,7 @@ void ccv_nnc_depalettize(const void* input, const int datatype, const int memory
#ifdef HAVE_CUDA
ccv_nnc_compat_depalettize(input, datatype, input_length, qbits, number_in_blocks, output, output_length, 0);
#elif defined(HAVE_MPS)
ccv_nnc_mps_depalettize(input, datatype, input_length, qbits, number_in_blocks, output, output_length, 0);
#else
assert(memory_type == CCV_TENSOR_CPU_MEMORY);
#endif
Expand Down
6 changes: 6 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ void mfa::cache<mfa::normalization::hash, mfa::normalization::pipeline>::prepare
_mfa_cache_prepare(&map, context, hash);
}

// Explicit specialization of cache::prepare for depalettize pipelines.
// Forwards this cache's map, the owning context and the lookup key to the
// shared `_mfa_cache_prepare` helper, exactly like the sibling specializations.
template <>
void mfa::cache<mfa::depalettize::hash, mfa::depalettize::pipeline>::prepare(mfa::context* context, mfa::depalettize::hash hash)
{
  _mfa_cache_prepare(&this->map, context, hash);
}

mfa::context::context(MTL::Device* device)
{
auto* pool = NS::AutoreleasePool::alloc()->init();
Expand Down
2 changes: 2 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "ccv_nnc_mfa_attention.hpp"
#include "ccv_nnc_mfa_gemm.hpp"
#include "ccv_nnc_mfa_normalization.hpp"
#include "ccv_nnc_mfa_depalettize.hpp"

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
Expand Down Expand Up @@ -44,6 +45,7 @@ class context {
cache<attention::hash, attention::pipeline> attention_cache;
cache<gemm::hash, gemm::pipeline> gemm_cache;
cache<normalization::hash, normalization::pipeline> normalization_cache;
cache<depalettize::hash, depalettize::pipeline> depalettize_cache;

MTL::Buffer* request_scratch(uint64_t size);
};
Expand Down
225 changes: 225 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
#include "ccv_nnc_mfa.hpp"
#include "ccv_nnc_mfa_hash.hpp"
#include <simd/simd.h>
using namespace ccv::nnc;

#include <string>

// MARK: - C

// C entry point: derives the cache key from `params` and asks the context's
// depalettize cache to prepare the matching pipeline. Must be called before
// ccv_nnc_mfa_encode_depalettize with the same parameters, since encoding
// fails when the key is not in the cache.
void ccv_nnc_mfa_prepare_depalettize(mfa::context* context, ccv_nnc_mfa_depalettize_params_t params)
{
  mfa::depalettize::hash key(params);
  context->depalettize_cache.prepare(context, key);
}

// C entry point: encodes one depalettize dispatch into `command_batch`.
//
// - context:        MFA context whose depalettize cache already holds a
//                   pipeline for `params` (see ccv_nnc_mfa_prepare_depalettize).
// - params:         parameters identifying the cached pipeline.
// - command_batch:  batch that provides the compute encoder.
// - tensors:        NULL-terminated list of exactly two buffers:
//                   [0] palettized source (read), [1] destination (written).
// - tensor_offsets: byte offset for each entry of `tensors`.
//
// Fails a precondition if the pipeline is not cached or if the tensor list
// does not contain exactly two buffers.
void ccv_nnc_mfa_encode_depalettize(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_depalettize_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets)
{
  mfa::depalettize::hash hash(params);
  auto iterator = context->depalettize_cache.map.find(hash);
  if (iterator == context->depalettize_cache.map.end()) {
    mfa::precondition_failure("Depalettize hash not cached.", __LINE__, __FILE__, __FUNCTION__);
  }

  // Count the tensors before touching the encoder. The scan is capped one past
  // the expected count so that an over-long list is rejected without reading
  // `tensor_offsets` beyond the two entries callers are required to provide.
  int num_tensors = 0;
  while (num_tensors <= 2 && tensors[num_tensors] != nullptr) {
    num_tensors += 1;
  }
  CCV_NNC_MFA_PRECONDITION(num_tensors == 2);

  auto* pipeline = iterator->second;
  auto encoder = command_batch->startCommand();

  // Buffer index i in the kernel corresponds to tensors[i].
  for (int i = 0; i < num_tensors; i++) {
    encoder->setBuffer(tensors[i], tensor_offsets[i], NS::UInteger(i));
  }

  encoder->setComputePipelineState(pipeline->depalettize_pso.get());
  encoder->useResource(tensors[0], MTL::ResourceUsageRead);
  encoder->useResource(tensors[1], MTL::ResourceUsageWrite);

  // The kernel only uses the x (tile within block) and y (block) grid
  // dimensions, so the depth is always a single slice.
  auto grid_size = pipeline->grid_size;
  grid_size.depth = 1;
  encoder->dispatchThreadgroups(grid_size, pipeline->group_size);
  command_batch->finishCommand(encoder);
}

// MARK: - C++

// Builds a cache key by copying every field of the C parameter struct.
mfa::depalettize::hash::hash(ccv_nnc_mfa_depalettize_params_t params)
  : data_type(params.data_type),
    qbits(params.qbits),
    number_in_blocks(params.number_in_blocks),
    length(params.length)
{
}

// Two keys are equal only when all four fields match.
bool mfa::depalettize::hash::operator==(const mfa::depalettize::hash& hash) const {
  if (data_type != hash.data_type) {
    return false;
  }
  if (qbits != hash.qbits) {
    return false;
  }
  if (number_in_blocks != hash.number_in_blocks) {
    return false;
  }
  return length == hash.length;
}

// Writes a human-readable dump of every field of the key to `os`.
std::ostream& operator<<(std::ostream& os, const mfa::depalettize::hash& hash) {
  return os
    << "mfa::depalettize::hash {"
    << " .data_type = " << hash.data_type << ','
    << " .qbits = " << hash.qbits << ','
    << " .number_in_blocks = " << hash.number_in_blocks << ','
    << " .length = " << hash.length << " "
    << "}";
}

// Folds all four key fields into one hash value. `qbits` and
// `number_in_blocks` are packed together into a single 64-bit word before
// being combined; `data_type` and `length` are combined as-is.
std::size_t std::hash<mfa::depalettize::hash>::operator()(const mfa::depalettize::hash& hash) const noexcept {
  using namespace mfa::hash;
  std::size_t result = 0;
  combine_64(result, hash.data_type);
  const auto packed_quantization = pack_64(simd::uint2 { (unsigned int)hash.qbits, (unsigned int)hash.number_in_blocks });
  combine_64(result, packed_quantization);
  combine_64(result, hash.length);
  return result;
}

// Compiles the Metal compute pipeline that expands a palettized buffer into
// dense `float` / `half` values, and records the dispatch geometry for it.
//
// Input layout, as read by the kernels below: the source is a sequence of
// blocks; each block starts with `palette_size` values of type `real`
// (64 for 6-bit, 256 for 8-bit) followed by the packed indices for
// `number_in_blocks` output values (3 bytes per 4 values for 6-bit, 1 byte per
// value for 8-bit).
//
// Dispatch geometry: grid.y selects the block, grid.x selects a tile inside
// the block. Each thread decodes 4 values per repeat, so one threadgroup
// covers threadgroup_size * 4 * num_repeats values (1024 for 6-bit, 2048 for
// 8-bit). Each threadgroup first stages the block's palette in threadgroup
// memory and then performs the lookups.
//
// Preconditions (fail fast instead of building an unusable pipeline):
// - hash.data_type is MTL::DataTypeFloat or MTL::DataTypeHalf;
// - hash.qbits is 6 or 8 (the only kernels implemented here);
// - hash.number_in_blocks is a positive multiple of the per-threadgroup tile,
//   because the kernels have no bounds checks or tail handling.
//
// NOTE(review): hash.length / hash.number_in_blocks rounds down, so a trailing
// partial block (length not a multiple of number_in_blocks) is not decoded.
// Confirm that callers never rely on a partial last block on this path.
mfa::depalettize::pipeline::pipeline(mfa::context* context, mfa::depalettize::hash hash) {
  CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf));
  CCV_NNC_MFA_PRECONDITION((hash.qbits == 6) || (hash.qbits == 8));

  const bool is_q6 = (hash.qbits == 6);
  const uint16_t threadgroup_size = 256;
  // The 8-bit kernel decodes two 4-value groups per thread, the 6-bit one.
  const int num_repeats = is_q6 ? 1 : 2;
  const int values_per_threadgroup = threadgroup_size * 4 * num_repeats;
  CCV_NNC_MFA_PRECONDITION((hash.number_in_blocks > 0) && (hash.number_in_blocks % values_per_threadgroup == 0));

  auto* pool = NS::AutoreleasePool::alloc()->init();

  std::string shader;
  if (is_q6) {
    // 6-bit: every 3 source bytes hold four 6-bit palette indices.
    shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void depalettize(
device uchar *source [[buffer(0)]],
device real *destination [[buffer(1)]],
uint3 tgid [[threadgroup_position_in_grid]],
ushort lid [[thread_index_in_threadgroup]]
) {
device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks * 3) * tgid.y;
threadgroup real palette[palette_size];
if (lid < palette_size) {
palette[lid] = ((device real*)ui0)[lid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
destination += number_in_blocks * 4 * tgid.y + tgid.x * threadgroup_size * num_repeats * 4;
device const uchar *ui1 = ui0 + sizeof(real) * palette_size + tgid.x * threadgroup_size * num_repeats * 3;
#pragma clang loop unroll(full)
for (uint k = 0; k < num_repeats; k++)
{
const uint8_t u0 = ui1[(k * threadgroup_size + lid) * 3];
const uint8_t u1 = ui1[(k * threadgroup_size + lid) * 3 + 1];
const uint8_t u2 = ui1[(k * threadgroup_size + lid) * 3 + 2];
destination[(k * threadgroup_size + lid) * 4] = palette[u0 >> 2];
destination[(k * threadgroup_size + lid) * 4 + 1] = palette[((u0 & 3) << 4) | (u1 >> 4)];
destination[(k * threadgroup_size + lid) * 4 + 2] = palette[((u1 & 15) << 2) | (u2 >> 6)];
destination[(k * threadgroup_size + lid) * 4 + 3] = palette[u2 & 63];
}
}
)";
  } else {
    // 8-bit: each 32-bit source word holds four 8-bit palette indices.
    shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void depalettize(
device uchar *source [[buffer(0)]],
device real *destination [[buffer(1)]],
uint3 tgid [[threadgroup_position_in_grid]],
ushort lid [[thread_index_in_threadgroup]]
) {
device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks) * tgid.y;
threadgroup real palette[palette_size];
if (lid < palette_size) {
palette[lid] = ((device real*)ui0)[lid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
destination += number_in_blocks * tgid.y + tgid.x * threadgroup_size * num_repeats * 4;
device const uint *ui1 = (device uint*)(ui0 + sizeof(real) * palette_size) + tgid.x * threadgroup_size * num_repeats;
#pragma clang loop unroll(full)
for (uint k = 0; k < num_repeats; k++)
{
const uint u0 = ui1[k * threadgroup_size + lid];
destination[(k * threadgroup_size + lid) * 4] = palette[u0 & 0xff];
destination[(k * threadgroup_size + lid) * 4 + 1] = palette[(u0 >> 8) & 0xff];
destination[(k * threadgroup_size + lid) * 4 + 2] = palette[(u0 >> 16) & 0xff];
destination[(k * threadgroup_size + lid) * 4 + 3] = palette[u0 >> 24];
}
}
)";
  }

  // Compile-time constants prepended to the kernel source.
  std::string defines = "";
  if (hash.data_type == MTL::DataTypeFloat) {
    defines += std::string("typedef float real;");
    defines += "\n";
  } else {
    defines += std::string("typedef half real;");
    defines += "\n";
  }

  defines += "constant ushort threadgroup_size = ";
  defines += std::to_string(threadgroup_size) + ";";
  defines += "\n";
  this->group_size = MTL::Size(threadgroup_size, 1, 1);

  defines += is_q6 ? "constant ushort palette_size = 64;\n" : "constant ushort palette_size = 256;\n";

  // The 6-bit kernel addresses data in groups of 4 values (3 packed bytes), so
  // its `number_in_blocks` constant is the group count; the 8-bit kernel uses
  // the value count directly.
  const int shader_number_in_blocks = is_q6 ? hash.number_in_blocks / 4 : hash.number_in_blocks;
  defines += "constant uint number_in_blocks = ";
  defines += std::to_string(shader_number_in_blocks) + ";";
  defines += "\n";

  defines += "constant uint num_repeats = ";
  defines += std::to_string(num_repeats) + ";";
  defines += "\n";

  // x: tiles per block, y: number of blocks. Kept in 64 bits so very long
  // tensors are not truncated before reaching MTL::Size.
  const uint64_t num_blocks = hash.length / (uint64_t)hash.number_in_blocks;
  const int tiles_per_block = hash.number_in_blocks / values_per_threadgroup;
  this->grid_size = MTL::Size(tiles_per_block, num_blocks, 1);

  auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());

  // Only the generated constants are logged; the kernel body is appended after.
  std::string source = defines;
  if (METAL_LOG_LEVEL(context) >= 4) {
    std::cerr << source << std::endl;
  }
  source += shader;

  NS::Error *error = nullptr;
  auto swift_source = NS::String::string(source.c_str(),
    NS::UTF8StringEncoding);
  auto library = NS::TransferPtr(context->device->newLibrary(swift_source, nullptr, &error));
  if (!library) {
    CCV_NNC_MFA_CHECK_ERROR(error)
  }

  auto swift_name = NS::String::string("depalettize", NS::UTF8StringEncoding);
  auto function = NS::TransferPtr(library->newFunction(swift_name, constants.get(), &error));
  if (!function) {
    CCV_NNC_MFA_CHECK_ERROR(error)
  }

  depalettize_pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error));
  if (!depalettize_pso) {
    CCV_NNC_MFA_CHECK_ERROR(error)
  }

  pool->drain();
}
66 changes: 66 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_depalettize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#ifndef GUARD_ccv_nnc_mfa_depalettize_hpp
#define GUARD_ccv_nnc_mfa_depalettize_hpp

// Parameters describing one depalettize operation. Shared between the C
// entry points and the C++ cache key (mfa::depalettize::hash), which copies
// every field.
typedef struct {
// Output element type; the pipeline accepts only the values of
// MTL::DataTypeFloat and MTL::DataTypeHalf.
uint64_t data_type;
// Bits per palette index. Kernels exist for 6 and 8 only.
int qbits;
// Number of output values that share one palette (one block).
int number_in_blocks;
// Total number of output values; divided by number_in_blocks to get the
// number of blocks to decode.
uint64_t length;
} ccv_nnc_mfa_depalettize_params_t;

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
#include <simd/simd.h>

namespace ccv {
namespace nnc {
namespace mfa {
namespace depalettize {

// Cache key for depalettize pipelines. Mirrors the fields of
// ccv_nnc_mfa_depalettize_params_t one-to-one; two keys are equal when all
// four fields are equal.
class hash {
public:
// Output element type (an MTL::DataType value).
uint64_t data_type;
// Bits per palette index.
int qbits;
// Number of output values per palette block.
int number_in_blocks;
// Total number of output values.
uint64_t length;

// Copies every field from the C parameter struct.
hash(ccv_nnc_mfa_depalettize_params_t);

// Field-wise equality.
bool operator==(const hash& rhs) const;
};

// Compiled depalettize kernel plus the dispatch geometry computed for one
// specific `hash`.
class pipeline {
public:
// Compute pipeline state for the generated "depalettize" kernel.
NS::SharedPtr<MTL::ComputePipelineState> depalettize_pso;

// Threadgroup counts: width = tiles per block, height = number of blocks.
MTL::Size grid_size;
// Threads per threadgroup (256 x 1 x 1).
MTL::Size group_size;

// Generates and compiles the kernel source for `hash` on context's device
// and fills in the members above.
pipeline(context* context, hash hash);
};

} // namespace depalettize
} // namespace mfa
} // namespace nnc
} // namespace ccv

// Streams a human-readable dump of all fields of a depalettize cache key.
std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::depalettize::hash& hash);

// std::hash specialization so depalettize::hash can be used as the key of the
// pipeline cache's map. The hash value is derived from all four key fields.
template<>
struct std::hash<ccv::nnc::mfa::depalettize::hash>
{
std::size_t operator()(const ccv::nnc::mfa::depalettize::hash& hash) const noexcept;
};

extern "C" {
#endif // __cplusplus

// Prepares the depalettize pipeline for `params` in the context's cache.
// Call this before ccv_nnc_mfa_encode_depalettize with the same `params`.
void ccv_nnc_mfa_prepare_depalettize(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_depalettize_params_t params);
// Encodes one depalettize dispatch into `command_batch`. `tensors` is a
// NULL-terminated list of exactly two buffers (source, then destination) and
// `tensor_offsets` holds the byte offset for each of them. Fails a
// precondition if no pipeline for `params` has been prepared.
void ccv_nnc_mfa_encode_depalettize(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_depalettize_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

#endif
2 changes: 1 addition & 1 deletion lib/nnc/mfa/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ include ../../config.mk

CFLAGS := -std=c++17 -O3 -Wall -I"../../" $(CFLAGS)

SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp 3rdparty/metal-cpp/Dispatch.cpp
SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp 3rdparty/metal-cpp/Dispatch.cpp

SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(SRCS)))

Expand Down
1 change: 1 addition & 0 deletions lib/nnc/mps/ccv_nnc_mps.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void ccv_nnc_deinit_stream_signal(ccv_nnc_stream_signal_t* const signal);
CCV_WARN_UNUSED(int) ccv_nnc_gpu_device_count(void);
void ccv_nnc_mps_unbounded_command_buffers(int state);
void ccv_nnc_mps_clear_graph_executable_cache(void);
// Depalettizes `input` into `output` on the GPU through the MFA depalettize
// kernel. `input` and `output` are passed on as Metal buffers; `datatype` is
// CCV_16F or CCV_32F, and `output_length` is the number of output values.
// NOTE(review): the current implementation does not read `input_length` or
// `command_buffer`; it always uses the command batch of stream context 0.
void ccv_nnc_mps_depalettize(const void* input, const int datatype, const size_t input_length, const int qbits, const int number_in_blocks, void* output, const size_t output_length, void* const command_buffer);

#ifdef __OBJC__

Expand Down
44 changes: 44 additions & 0 deletions lib/nnc/mps/ccv_nnc_palettize.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "ccv_nnc_mps.h"
#include "ccv_internal.h"
#include "nnc/ccv_nnc_internal.h"
#include "nnc/ccv_nnc_easy.h"

// Runs the MFA depalettize kernel, expanding the palettized data in `input`
// into `output_length` values of type `datatype` in `output`.
//
// `input` and `output` are handed to the encoder as Metal buffers at offset 0.
// `input_length` and `command_buffer` are accepted for signature parity but
// are not read here; the work is always encoded on stream context 0.
void ccv_nnc_mps_depalettize(const void* input, const int datatype, const size_t input_length, const int qbits, const int number_in_blocks, void* output, const size_t output_length, void* const command_buffer)
{
	// Raw MTLDataType values: 16 is half, 3 is float. Anything else is left as
	// UINT32_MAX, which the pipeline constructor rejects.
	uint32_t mtl_data_type;
	if (datatype == CCV_16F)
		mtl_data_type = 16;
	else if (datatype == CCV_32F)
		mtl_data_type = 3;
	else
		mtl_data_type = UINT32_MAX;

	ccv_nnc_mfa_depalettize_params_t params = {
		.data_type = mtl_data_type,
		.qbits = (uint32_t)qbits,
		.number_in_blocks = (uint32_t)number_in_blocks,
		.length = (uint64_t)output_length,
	};

	// NULL-terminated buffer list: source first, destination second.
	mtl_buffer_t* buffers[3] = {
		(mtl_buffer_t*)input, // A
		(mtl_buffer_t*)output, // B
		NULL,
	};
	size_t buffer_offsets[2] = {
		0, // A offset
		0, // B offset
	};

	// The pipeline has to be cached before it can be encoded.
	ccv_nnc_mfa_context_t* const mfa_context = ccv_nnc_default_mfa_context();
	ccv_nnc_mfa_prepare_depalettize(mfa_context, params);

	mtl_command_batch_t* const batch = ccv_nnc_stream_context_start_command_batch(0);
	ccv_nnc_mfa_encode_depalettize(mfa_context, params, batch, buffers, buffer_offsets);
	ccv_nnc_stream_context_finish_command_batch(0, batch);
}
Loading

0 comments on commit 5a05d9f

Please sign in to comment.