diff --git a/lib/nnc/ccv_nnc_palettize.c b/lib/nnc/ccv_nnc_palettize.c
index 15ad3e28a..c7f55e6f2 100644
--- a/lib/nnc/ccv_nnc_palettize.c
+++ b/lib/nnc/ccv_nnc_palettize.c
@@ -964,6 +964,7 @@ void ccv_nnc_depalettize(const void* input, const int datatype, const int memory
 #ifdef HAVE_CUDA
 	ccv_nnc_compat_depalettize(input, datatype, input_length, qbits, number_in_blocks, output, output_length, 0);
 #elif defined(HAVE_MPS)
+	ccv_nnc_mps_depalettize(input, datatype, input_length, qbits, number_in_blocks, output, output_length, 0);
 #else
 	assert(memory_type == CCV_TENSOR_CPU_MEMORY);
 #endif
diff --git a/lib/nnc/mfa/ccv_nnc_mfa.cpp b/lib/nnc/mfa/ccv_nnc_mfa.cpp
index 644842c0e..135d40dd0 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa.cpp
@@ -94,6 +94,12 @@ void mfa::cache::prepare
   _mfa_cache_prepare(&map, context, hash);
 }
 
+template <>
+void mfa::cache<mfa::depalettize::hash, mfa::depalettize::pipeline>::prepare(mfa::context* context, mfa::depalettize::hash hash)
+{
+  _mfa_cache_prepare(&map, context, hash);
+}
+
 mfa::context::context(MTL::Device* device)
 {
   auto* pool = NS::AutoreleasePool::alloc()->init();
diff --git a/lib/nnc/mfa/ccv_nnc_mfa.hpp b/lib/nnc/mfa/ccv_nnc_mfa.hpp
index 8dec4dfc8..29bdf51d5 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa.hpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa.hpp
@@ -6,6 +6,7 @@
 #include "ccv_nnc_mfa_attention.hpp"
 #include "ccv_nnc_mfa_gemm.hpp"
 #include "ccv_nnc_mfa_normalization.hpp"
+#include "ccv_nnc_mfa_depalettize.hpp"
 
 #ifdef __cplusplus
 #include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
@@ -44,6 +45,7 @@ class context {
   cache<attention::hash, attention::pipeline> attention_cache;
   cache<gemm::hash, gemm::pipeline> gemm_cache;
   cache<normalization::hash, normalization::pipeline> normalization_cache;
+  cache<depalettize::hash, depalettize::pipeline> depalettize_cache;
 
   MTL::Buffer* request_scratch(uint64_t size);
 };
diff --git a/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp b/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp
new file mode 100644
index 000000000..f62815536
--- /dev/null
+++ b/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp
@@ -0,0 +1,225 @@
+#include "ccv_nnc_mfa.hpp"
+#include "ccv_nnc_mfa_hash.hpp"
+#include <simd/simd.h>
+using namespace ccv::nnc;
+
+#include <string>
+
+// MARK: - C
+
+void ccv_nnc_mfa_prepare_depalettize(mfa::context* context, ccv_nnc_mfa_depalettize_params_t params)
+{
+  context->depalettize_cache.prepare(context, mfa::depalettize::hash(params));
+}
+
+void ccv_nnc_mfa_encode_depalettize(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_depalettize_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets)
+{
+  mfa::depalettize::hash hash(params);
+  auto iterator = context->depalettize_cache.map.find(hash);
+  if (iterator == context->depalettize_cache.map.end()) {
+    mfa::precondition_failure("Depalettize hash not cached.", __LINE__, __FILE__, __FUNCTION__);
+  }
+
+  auto* pipeline = iterator->second;
+  auto encoder = command_batch->startCommand();
+
+  int num_tensors = 0;
+  while (tensors[num_tensors] != nullptr) {
+    encoder->setBuffer(tensors[num_tensors], tensor_offsets[num_tensors], NS::UInteger(num_tensors));
+    num_tensors += 1;
+  }
+  CCV_NNC_MFA_PRECONDITION(num_tensors == 2);
+
+  encoder->setComputePipelineState(pipeline->depalettize_pso.get());
+  encoder->useResource(tensors[0], MTL::ResourceUsageRead);
+  encoder->useResource(tensors[1], MTL::ResourceUsageWrite);
+
+  auto grid_size = pipeline->grid_size;
+  grid_size.depth = 1;
+  CCV_NNC_MFA_PRECONDITION(grid_size.depth > 0);
+  encoder->dispatchThreadgroups(grid_size, pipeline->group_size);
+  command_batch->finishCommand(encoder);
+}
+
+// MARK: - C++
+
+mfa::depalettize::hash::hash(ccv_nnc_mfa_depalettize_params_t params) {
+  data_type = params.data_type;
+  qbits = params.qbits;
+  number_in_blocks = params.number_in_blocks;
+  length = params.length;
+}
+
+bool mfa::depalettize::hash::operator==(const mfa::depalettize::hash& hash) const {
+  return
+  (data_type == hash.data_type) &&
+  (qbits == hash.qbits) &&
+  (number_in_blocks == hash.number_in_blocks) &&
+  (length == hash.length);
+}
+
+std::ostream& operator<<(std::ostream& os, const mfa::depalettize::hash& hash) {
+  os << "mfa::depalettize::hash {";
+  os << " .data_type = " << hash.data_type << ',';
+  os << " .qbits = " << hash.qbits << ',';
+  os << " .number_in_blocks = " << hash.number_in_blocks << ',';
+  os << " .length = " << hash.length << " ";
+  os << "}";
+  return os;
+}
+
+std::size_t std::hash<mfa::depalettize::hash>::operator()(const mfa::depalettize::hash& hash) const noexcept {
+  std::size_t seed = 0;
+  using namespace mfa::hash;
+  combine_64(seed, hash.data_type);
+  combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.qbits, (unsigned int)hash.number_in_blocks }));
+  combine_64(seed, hash.length);
+  return seed;
+}
+
+mfa::depalettize::pipeline::pipeline(mfa::context* context, mfa::depalettize::hash hash) {
+  // FlashNorm not supported for group depalettize yet.
+  CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))
+
+  auto* pool = NS::AutoreleasePool::alloc()->init();
+
+  std::string shader;
+  if (hash.qbits == 6) {
+    shader = R"(
+#include <metal_stdlib>
+using namespace metal;
+
+kernel void depalettize(
+  device uchar *source [[buffer(0)]],
+  device real *destination [[buffer(1)]],
+
+  uint3 tgid [[threadgroup_position_in_grid]],
+  ushort lid [[thread_index_in_threadgroup]]
+) {
+  device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks * 3) * tgid.y;
+  threadgroup real palette[palette_size];
+  if (lid < palette_size) {
+    palette[lid] = ((device real*)ui0)[lid];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  destination += number_in_blocks * 4 * tgid.y + tgid.x * threadgroup_size * num_repeats * 4;
+  device const uchar *ui1 = ui0 + sizeof(real) * palette_size + tgid.x * threadgroup_size * num_repeats * 3;
+  #pragma clang loop unroll(full)
+  for (uint k = 0; k < num_repeats; k++)
+  {
+    const uint8_t u0 = ui1[(k * threadgroup_size + lid) * 3];
+    const uint8_t u1 = ui1[(k * threadgroup_size + lid) * 3 + 1];
+    const uint8_t u2 = ui1[(k * threadgroup_size + lid) * 3 + 2];
+    destination[(k * threadgroup_size + lid) * 4] = palette[u0 >> 2];
+    destination[(k * threadgroup_size + lid) * 4 + 1] = palette[((u0 & 3) << 4) | (u1 >> 4)];
+    destination[(k * threadgroup_size + lid) * 4 + 2] = palette[((u1 & 15) << 2) | (u2 >> 6)];
+    destination[(k * threadgroup_size + lid) * 4 + 3] = palette[u2 & 63];
+  }
+}
+  )";
+  } else if (hash.qbits == 8) {
+    shader = R"(
+#include <metal_stdlib>
+using namespace metal;
+
+kernel void depalettize(
+  device uchar *source [[buffer(0)]],
+  device real *destination [[buffer(1)]],
+
+  uint3 tgid [[threadgroup_position_in_grid]],
+  ushort lid [[thread_index_in_threadgroup]]
+) {
+  device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks) * tgid.y;
+  threadgroup real palette[palette_size];
+  if (lid < palette_size) {
+    palette[lid] = ((device real*)ui0)[lid];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  destination += number_in_blocks * tgid.y + tgid.x * threadgroup_size * num_repeats * 4;
+  device const uint *ui1 = (device uint*)(ui0 + sizeof(real) * palette_size) + tgid.x * threadgroup_size * num_repeats;
+  #pragma clang loop unroll(full)
+  for (uint k = 0; k < num_repeats; k++)
+  {
+    const uint u0 = ui1[k * threadgroup_size + lid];
+    destination[(k * threadgroup_size + lid) * 4] = palette[u0 & 0xff];
+    destination[(k * threadgroup_size + lid) * 4 + 1] = palette[(u0 >> 8) & 0xff];
+    destination[(k * threadgroup_size + lid) * 4 + 2] = palette[(u0 >> 16) & 0xff];
+    destination[(k * threadgroup_size + lid) * 4 + 3] = palette[u0 >> 24];
+  }
+}
+  )";
+  }
+
+  std::string defines = "";
+  if (hash.data_type == MTL::DataTypeFloat) {
+    defines += std::string("typedef float real;");
+    defines += "\n";
+  } else {
+    defines += std::string("typedef half real;");
+    defines += "\n";
+  }
+
+  uint16_t threadgroup_size = 256;
+  defines += "constant ushort threadgroup_size = ";
+  defines += std::to_string(threadgroup_size) + ";";
+  defines += "\n";
+  this->group_size = MTL::Size(threadgroup_size, 1, 1);
+
+  if (hash.qbits == 6) {
+    defines += "constant ushort palette_size = 64;\n";
+
+    defines += "constant uint number_in_blocks = ";
+    defines += std::to_string(hash.number_in_blocks / 4) + ";";
+    defines += "\n";
+    const int num_blocks = hash.length / hash.number_in_blocks;
+    const int repeat_4 = hash.number_in_blocks / (256 * 4);
+
+    defines += "constant uint num_repeats = ";
+    defines += std::to_string(1) + ";";
+    defines += "\n";
+    this->grid_size = MTL::Size(repeat_4, num_blocks, 1);
+  } else if (hash.qbits == 8) {
+    defines += "constant ushort palette_size = 256;\n";
+
+    defines += "constant uint number_in_blocks = ";
+    defines += std::to_string(hash.number_in_blocks) + ";";
+    defines += "\n";
+    const int num_blocks = hash.length / hash.number_in_blocks;
+    const int repeat_4 = hash.number_in_blocks / (256 * 4 * 2);
+
+    defines += "constant uint num_repeats = ";
+    defines += std::to_string(2) + ";";
+    defines += "\n";
+    this->grid_size = MTL::Size(repeat_4, num_blocks, 1);
+  }
+
+  auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
+  NS::SharedPtr<MTL::ComputePipelineState>* pso = &depalettize_pso;
+
+  std::string source = defines;
+  if (METAL_LOG_LEVEL(context) >= 4) {
+    std::cerr << source << std::endl;
+  }
+  source += shader;
+
+  NS::Error *error = nullptr;
+  auto swift_source = NS::String::string(source.c_str(),
+   NS::UTF8StringEncoding);
+  auto library = NS::TransferPtr(context->device->newLibrary(swift_source, nullptr, &error));
+  if (!library) {
+    CCV_NNC_MFA_CHECK_ERROR(error)
+  }
+
+  auto swift_name = NS::String::string("depalettize", NS::UTF8StringEncoding);
+  auto function = NS::TransferPtr(library->newFunction(swift_name, constants.get(), &error));
+  if (!function) {
+    CCV_NNC_MFA_CHECK_ERROR(error)
+  }
+
+  *pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error));
+  if (!*pso) {
+    CCV_NNC_MFA_CHECK_ERROR(error)
+  }
+
+  pool->drain();
+}
diff --git a/lib/nnc/mfa/ccv_nnc_mfa_depalettize.hpp b/lib/nnc/mfa/ccv_nnc_mfa_depalettize.hpp
new file mode 100644
index 000000000..4257ac536
--- /dev/null
+++ b/lib/nnc/mfa/ccv_nnc_mfa_depalettize.hpp
@@ -0,0 +1,66 @@
+#ifndef GUARD_ccv_nnc_mfa_depalettize_hpp
+#define GUARD_ccv_nnc_mfa_depalettize_hpp
+
+typedef struct {
+  uint64_t data_type;
+  int qbits;
+  int number_in_blocks;
+  uint64_t length;
+} ccv_nnc_mfa_depalettize_params_t;
+
+#ifdef __cplusplus
+#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
+#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
+#include <simd/simd.h>
+
+namespace ccv {
+namespace nnc {
+namespace mfa {
+namespace depalettize {
+
+class hash {
+public:
+  uint64_t data_type;
+  int qbits;
+  int number_in_blocks;
+  uint64_t length;
+
+  hash(ccv_nnc_mfa_depalettize_params_t);
+
+  bool operator==(const hash& rhs) const;
+};
+
+class pipeline {
+public:
+  NS::SharedPtr<MTL::ComputePipelineState> depalettize_pso;
+
+  MTL::Size grid_size;
+  MTL::Size group_size;
+
+  pipeline(context* context, hash hash);
+};
+
+} // namespace depalettize
+} // namespace mfa
+} // namespace nnc
+} // namespace ccv
+
+std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::depalettize::hash& hash);
+
+template<>
+struct std::hash<ccv::nnc::mfa::depalettize::hash>
+{
+  std::size_t operator()(const ccv::nnc::mfa::depalettize::hash& hash) const noexcept;
+};
+
+extern "C" {
+#endif // __cplusplus
+
+void ccv_nnc_mfa_prepare_depalettize(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_depalettize_params_t params);
+void ccv_nnc_mfa_encode_depalettize(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_depalettize_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif
diff --git a/lib/nnc/mfa/makefile b/lib/nnc/mfa/makefile
index c7efef61f..0d3a76c95 100644
--- a/lib/nnc/mfa/makefile
+++ b/lib/nnc/mfa/makefile
@@ -2,7 +2,7 @@ include ../../config.mk
 
 CFLAGS := -std=c++17 -O3 -Wall -I"../../" $(CFLAGS)
 
-SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp 3rdparty/metal-cpp/Dispatch.cpp
+SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp 3rdparty/metal-cpp/Dispatch.cpp
 
 SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(SRCS)))
 
diff --git a/lib/nnc/mps/ccv_nnc_mps.h b/lib/nnc/mps/ccv_nnc_mps.h
index 42d88e2e6..b8177d715 100644
--- a/lib/nnc/mps/ccv_nnc_mps.h
+++ b/lib/nnc/mps/ccv_nnc_mps.h
@@ -32,6 +32,7 @@
 void ccv_nnc_deinit_stream_signal(ccv_nnc_stream_signal_t* const signal);
 CCV_WARN_UNUSED(int) ccv_nnc_gpu_device_count(void);
 void ccv_nnc_mps_unbounded_command_buffers(int state);
 void ccv_nnc_mps_clear_graph_executable_cache(void);
+void ccv_nnc_mps_depalettize(const void* input, const int datatype, const size_t input_length, const int qbits, const int number_in_blocks, void* output, const size_t output_length, void* const command_buffer);
 
 #ifdef __OBJC__
diff --git a/lib/nnc/mps/ccv_nnc_palettize.m b/lib/nnc/mps/ccv_nnc_palettize.m
new file mode 100644
index 000000000..69f37b7b5
--- /dev/null
+++ b/lib/nnc/mps/ccv_nnc_palettize.m
@@ -0,0 +1,44 @@
+#include "ccv_nnc_mps.h"
+#include "ccv_internal.h"
+#include "nnc/ccv_nnc_internal.h"
+#include "nnc/ccv_nnc_easy.h"
+
+void ccv_nnc_mps_depalettize(const void* input, const int datatype, const size_t input_length, const int qbits, const int number_in_blocks, void* output, const size_t output_length, void* const command_buffer)
+{
+	uint32_t mtl_data_type = UINT32_MAX;
+	switch (datatype) {
+		case CCV_16F: {
+			mtl_data_type = 16;
+			break;
+		}
+		case CCV_32F: {
+			mtl_data_type = 3;
+			break;
+		}
+		default: {
+			break;
+		}
+	}
+	ccv_nnc_mfa_depalettize_params_t params = {
+		.data_type = mtl_data_type,
+		.qbits = (uint32_t)qbits,
+		.number_in_blocks = (uint32_t)number_in_blocks,
+		.length = (uint64_t)output_length,
+	};
+	ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
+
+	ccv_nnc_mfa_prepare_depalettize(context, params);
+
+	mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(0);
+	mtl_buffer_t* tensors[3] = {
+		(mtl_buffer_t*)input, // A
+		(mtl_buffer_t*)output, // B
+		NULL,
+	};
+	size_t tensor_offsets[2] = {
+		0, // A offset
+		0, // B offset
+	};
+	ccv_nnc_mfa_encode_depalettize(context, params, command_batch, tensors, tensor_offsets);
+	ccv_nnc_stream_context_finish_command_batch(0, command_batch);
+}
diff --git a/lib/nnc/mps/makefile b/lib/nnc/mps/makefile
index cbba0a9c9..a5e41211f 100644
--- a/lib/nnc/mps/makefile
+++ b/lib/nnc/mps/makefile
@@ -2,7 +2,7 @@ include ../../config.mk
 
 CFLAGS := -O3 -Wall -I"../../" $(CFLAGS)
 
-SRCS := ccv_nnc_mps.m
+SRCS := ccv_nnc_mps.m ccv_nnc_palettize.m
 
 SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.m,%.o,$(SRCS)))
 
diff --git a/test/int/nnc/palettize.tests.c b/test/int/nnc/palettize.tests.c
index 4bc15df15..d855ca2f9 100644
--- a/test/int/nnc/palettize.tests.c
+++ b/test/int/nnc/palettize.tests.c
@@ -647,27 +647,27 @@ TEST_CASE("quantize float to 6-bit and dequantize on GPU losslessly, fast path")
 
 TEST_CASE("quantize half-precision to 6-bit and dequantize on GPU losslessly, fast path")
 {
-	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_DATA_TRANSFER_FORWARD, CCV_NNC_BACKEND_GPU_REF));
+	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_DATA_TRANSFER_FORWARD, CCV_NNC_BACKEND_GPU_REF) || ccv_nnc_cmd_ok(CCV_NNC_DATA_TRANSFER_FORWARD, CCV_NNC_BACKEND_MPS));
 	float lut_f32[64];
 	int i;
 	for (i = 0; i < 64; i++)
 		lut_f32[i] = (float)i;
 	uint16_t lut[64];
 	ccv_float_to_half_precision(lut_f32, lut, 64);
-	uint16_t* const values = ccmalloc(sizeof(uint16_t) * 2840);
-	for (i = 0; i < 2840; i++)
+	uint16_t* const values = ccmalloc(sizeof(uint16_t) * 8192);
+	for (i = 0; i < 8192; i++)
 		values[i] = lut[i % 64];
-	ccv_nnc_tensor_t* tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, (2130 + 6 * 64 * 2 + 3) / 4), 0);
+	ccv_nnc_tensor_t* tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, (6144 + 2 * 64 * 2 + 3) / 4), 0);
 	uint8_t* compressed = tensor->data.u8;
-	const size_t output_size = ccv_nnc_palettize(values, CCV_16F, CCV_TENSOR_CPU_MEMORY, 2840, 6, 512, compressed, 2130 + 6 * 64 * 2);
-	REQUIRE_EQ(output_size, 2130 + 6 * 64 * 2, "output size should match");
-	ccv_nnc_tensor_t* g_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, (2130 + 6 * 64 * 2 + 3) / 4), 0);
+	const size_t output_size = ccv_nnc_palettize(values, CCV_16F, CCV_TENSOR_CPU_MEMORY, 8192, 6, 4096, compressed, 6144 + 2 * 64 * 2);
+	REQUIRE_EQ(output_size, 6144 + 2 * 64 * 2, "output size should match");
+	ccv_nnc_tensor_t* g_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, (6144 + 2 * 64 * 2 + 3) / 4), 0);
 	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(tensor), TENSOR_LIST(g_tensor), 0);
-	ccv_nnc_tensor_t* gv_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 16F, 2840), 0);
-	ccv_nnc_depalettize(g_tensor->data.u8, CCV_16F, CCV_TENSOR_GPU_MEMORY, output_size, 6, 512, gv_tensor->data.u8, 2840);
-	ccv_nnc_tensor_t* v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(16F, 2840), 0);
+	ccv_nnc_tensor_t* gv_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 16F, 8192), 0);
+	ccv_nnc_depalettize(g_tensor->data.u8, CCV_16F, CCV_TENSOR_GPU_MEMORY, output_size, 6, 4096, gv_tensor->data.u8, 8192);
+	ccv_nnc_tensor_t* v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(16F, 8192), 0);
 	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gv_tensor), TENSOR_LIST(v_tensor), 0);
-	REQUIRE_ARRAY_EQ(uint16_t, values, v_tensor->data.f16, 2840, "should be lossless");
+	REQUIRE_ARRAY_EQ(uint16_t, values, v_tensor->data.f16, 8192, "should be lossless");
 	ccfree(values);
 	ccv_nnc_tensor_free(tensor);
 	ccv_nnc_tensor_free(g_tensor);
@@ -791,25 +791,25 @@ TEST_CASE("quantize double to 8-bit and dequantize on GPU losslessly, fast path")
 
 TEST_CASE("quantize float to 8-bit and dequantize on GPU losslessly, fast path")
 {
-	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_DATA_TRANSFER_FORWARD, CCV_NNC_BACKEND_GPU_REF));
+	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_DATA_TRANSFER_FORWARD, CCV_NNC_BACKEND_GPU_REF) || ccv_nnc_cmd_ok(CCV_NNC_DATA_TRANSFER_FORWARD, CCV_NNC_BACKEND_MPS));
 	float lut[256];
 	int i;
 	for (i = 0; i < 256; i++)
 		lut[i] = (float)i;
-	float* const values = ccmalloc(sizeof(float) * 2840);
-	for (i = 0; i < 2840; i++)
+	float* const values = ccmalloc(sizeof(float) * 8192);
+	for (i = 0; i < 8192; i++)
 		values[i] = lut[i % 256];
-	ccv_nnc_tensor_t* tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, (2840 + 3 * 256 * 4 + 3) / 4), 0);
+	ccv_nnc_tensor_t* tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, (8192 + 2 * 256 * 4 + 3) / 4), 0);
 	uint8_t* compressed = tensor->data.u8;
-	const size_t output_size = ccv_nnc_palettize(values, CCV_32F, CCV_TENSOR_CPU_MEMORY, 2840, 8, 1280, compressed, 2840 + 3 * 256 * 4);
-	REQUIRE_EQ(output_size, 2840 + 3 * 256 * 4, "output size should match");
-	ccv_nnc_tensor_t* g_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, (2840 + 3 * 256 * 4 + 3) / 4), 0);
+	const size_t output_size = ccv_nnc_palettize(values, CCV_32F, CCV_TENSOR_CPU_MEMORY, 8192, 8, 4096, compressed, 8192 + 2 * 256 * 4);
+	REQUIRE_EQ(output_size, 8192 + 2 * 256 * 4, "output size should match");
+	ccv_nnc_tensor_t* g_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, (8192 + 2 * 256 * 4 + 3) / 4), 0);
 	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(tensor), TENSOR_LIST(g_tensor), 0);
-	ccv_nnc_tensor_t* gv_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, 2840), 0);
-	ccv_nnc_depalettize(g_tensor->data.u8, CCV_32F, CCV_TENSOR_GPU_MEMORY, output_size, 8, 1280, gv_tensor->data.u8, 2840);
-	ccv_nnc_tensor_t* v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 2840), 0);
+	ccv_nnc_tensor_t* gv_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, 8192), 0);
+	ccv_nnc_depalettize(g_tensor->data.u8, CCV_32F, CCV_TENSOR_GPU_MEMORY, output_size, 8, 4096, gv_tensor->data.u8, 8192);
+	ccv_nnc_tensor_t* v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 8192), 0);
 	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gv_tensor), TENSOR_LIST(v_tensor), 0);
-	REQUIRE_ARRAY_EQ(float, values, v_tensor->data.f32, 2840, "should be lossless");
+	REQUIRE_ARRAY_EQ(float, values, v_tensor->data.f32, 8192, "should be lossless");
 	ccfree(values);
 	ccv_nnc_tensor_free(tensor);
 	ccv_nnc_tensor_free(g_tensor);
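
Reference note (not part of the patch): the 6-bit Metal kernel above decodes three packed bytes into four palette indices per thread, using the shift/mask expressions u0 >> 2, ((u0 & 3) << 4) | (u1 >> 4), ((u1 & 15) << 2) | (u2 >> 6), and u2 & 63. A minimal host-side C sketch of that packing and the matching unpack, with hypothetical helper names pack6/unpack6 introduced only for illustration:

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Pack four 6-bit palette indices (0..63) into three bytes, in the layout the
 * 6-bit kernel expects. Helper names are illustrative, not part of the library. */
static void pack6(const uint8_t idx[4], uint8_t out[3])
{
	out[0] = (uint8_t)((idx[0] << 2) | (idx[1] >> 4));
	out[1] = (uint8_t)((idx[1] << 4) | (idx[2] >> 2));
	out[2] = (uint8_t)((idx[2] << 6) | idx[3]);
}

/* Recover the four indices with the same expressions the kernel uses. */
static void unpack6(const uint8_t in[3], uint8_t idx[4])
{
	idx[0] = in[0] >> 2;
	idx[1] = (uint8_t)(((in[0] & 3) << 4) | (in[1] >> 4));
	idx[2] = (uint8_t)(((in[1] & 15) << 2) | (in[2] >> 6));
	idx[3] = in[2] & 63;
}

int main(void)
{
	const uint8_t idx[4] = { 5, 42, 63, 0 };
	uint8_t packed[3], recovered[4];
	pack6(idx, packed);
	unpack6(packed, recovered);
	for (int i = 0; i < 4; i++)
		assert(idx[i] == recovered[i]); /* round-trip is lossless */
	printf("6-bit pack/unpack round-trip OK\n");
	return 0;
}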