Support jit dequantization during training.
liuliu committed Sep 27, 2023
1 parent 3fd457b commit 0ce86f0
Showing 2 changed files with 238 additions and 14 deletions.
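
This commit teaches the Metal backward passes for GEMM and convolution to dequantize palettized (CCV_QX) tensors just in time: whenever the activations a or the weights w arrive in palette form, a depalettize kernel is encoded ahead of the matrix multiply, expanding the tensor into a scratch buffer that the subsequent GEMM or MPSGraph executable reads instead of the packed original. The packed CCV_QX datatype is decoded with the same bit twiddling throughout the diff; here is a minimal sketch of that decoding, with hypothetical helper names and a field layout inferred from the expressions in this commit rather than from a documented spec:

// Sketch only: helper names are hypothetical; layout inferred from this diff.
static inline int qx_palette_datatype(const int datatype)
{
	// The low byte holds the underlying element type, down-shifted by 12 bits;
	// e.g. 0x20 << 12 == 0x20000 == CCV_16F.
	return (datatype & 0xff) << 12;
}

static inline int qx_qbits(const int datatype)
{
	// Bits 8..11 hold the number of bits per palette index.
	return (datatype & 0xf00) >> 8;
}

// tensor_param.reserved carries number_in_blocks: how many consecutive values
// share one palette block.
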
204 changes: 193 additions & 11 deletions lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
@@ -606,8 +606,80 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_supported =
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && is_same_batch && !bias && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION);

size_t a_data_size = 0;
if (a && dw && CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
{
ccv_nnc_tensor_param_t a_params = a->info;
const int palette_datatype = (a_params.datatype & 0xff) << 12;
ccv_nnc_tensor_param_t depalettize_a_params = a_params;
depalettize_a_params.datatype = palette_datatype;
depalettize_a_params.reserved = 0;
a_data_size = ccv_nnc_tensor_data_size(depalettize_a_params);
}
size_t w_data_size = 0;
if (w && h && CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
{
ccv_nnc_tensor_param_t w_params = w->info;
const int palette_datatype = (w_params.datatype & 0xff) << 12;
ccv_nnc_tensor_param_t depalettize_w_params = w_params;
depalettize_w_params.datatype = palette_datatype;
depalettize_w_params.reserved = 0;
w_data_size = ccv_nnc_tensor_data_size(depalettize_w_params);
}
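// Scratch layout used throughout: the depalettized copy of a (if any)
// occupies bytes [0, a_data_size) and the depalettized copy of w follows at
// offset a_data_size, so one buffer of a_data_size + w_data_size bytes
// serves both GEMMs.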
if (is_mfa_supported)
{
mtl_buffer_t* scratch = 0;
if (a_data_size + w_data_size > 0)
scratch = ccv_nnc_mfa_request_scratch(context, a_data_size + w_data_size);
mtl_buffer_t* a_data = 0;
size_t a_dataof = 0;
ccv_nnc_mfa_depalettize_params_t a_depalettize_params;
if (a && dw)
{
a_data = mpgetbuffer((ccv_nnc_tensor_t*)a);
a_dataof = (size_t)mpgetoffset((ccv_nnc_tensor_t*)a);
if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
{
ccv_nnc_tensor_param_t a_params = a->info;
const size_t count = ccv_nnc_tensor_count(a_params);
const int qbits = (a_params.datatype & 0xf00) >> 8;
const int number_in_blocks = a_params.reserved;
a_depalettize_params = (ccv_nnc_mfa_depalettize_params_t){
.data_type = mtl_data_type,
.qbits = (uint32_t)qbits,
.number_in_blocks = (uint32_t)number_in_blocks,
.length = (uint64_t)count,
};
ccv_nnc_mfa_prepare_depalettize(context, a_depalettize_params);
a_data = scratch;
a_dataof = 0;
}
}
mtl_buffer_t* w_data = 0;
size_t w_dataof = 0;
ccv_nnc_mfa_depalettize_params_t w_depalettize_params;
if (w && h)
{
w_data = mpgetbuffer((ccv_nnc_tensor_t*)w);
w_dataof = (size_t)mpgetoffset((ccv_nnc_tensor_t*)w);
if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
{
ccv_nnc_tensor_param_t w_params = w->info;
const size_t count = ccv_nnc_tensor_count(w_params);
const int qbits = (w_params.datatype & 0xf00) >> 8;
const int number_in_blocks = w_params.reserved;
w_depalettize_params = (ccv_nnc_mfa_depalettize_params_t){
.data_type = mtl_data_type,
.qbits = (uint32_t)qbits,
.number_in_blocks = (uint32_t)number_in_blocks,
.length = (uint64_t)count,
};
ccv_nnc_mfa_prepare_depalettize(context, w_depalettize_params);
w_data = scratch;
w_dataof = a_data_size;
}
}
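// Note the prepare/encode split: ccv_nnc_mfa_prepare_depalettize runs here,
// before the command batch is assembled, so the pipeline state is ready; the
// matching ccv_nnc_mfa_encode_depalettize calls below record the actual
// dispatches ahead of the GEMMs that consume the scratch buffer.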
ccv_nnc_mfa_gemm_params_t h_params;
// On supported devices, use Metal directly.
if (h)
@@ -784,18 +856,32 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
// faster than the >50 µs penalty for MPSGraph (probably why
// MPSMatrixMultiplication is faster for GEMM).
mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);

if (h)
{
if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
{
mtl_buffer_t* tensors[3] = {
mpgetbuffer((ccv_nnc_tensor_t*)w), // A
(mtl_buffer_t*)scratch, // B
NULL,
};
size_t tensor_offsets[2] = {
w->dataof, // A offset
a_data_size, // B offset
};
ccv_nnc_mfa_encode_depalettize(context, w_depalettize_params, command_batch, tensors, tensor_offsets);
}
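// From here on, w_data/w_dataof point at the depalettized weights in scratch
// when w was palettized, and at the original buffer otherwise.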
if (is_transpose_a)
{
mtl_buffer_t* tensors[4] = {
-mpgetbuffer((ccv_nnc_tensor_t*)w), // A
w_data, // A
mpgetbuffer((ccv_nnc_tensor_t*)g), // B
mpgetbuffer((ccv_nnc_tensor_t*)h), // C
NULL,
};
size_t tensor_offsets[4] = {
-w->dataof, // A offset
w_dataof, // A offset
g->dataof, // B offset
h->dataof, // C offset
0, // D offset
@@ -804,13 +890,13 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
} else {
mtl_buffer_t* tensors[4] = {
mpgetbuffer((ccv_nnc_tensor_t*)g), // A
-mpgetbuffer((ccv_nnc_tensor_t*)w), // B
w_data, // B
mpgetbuffer((ccv_nnc_tensor_t*)h), // C
NULL,
};
size_t tensor_offsets[4] = {
g->dataof, // A offset
-w->dataof, // B offset
w_dataof, // B offset
h->dataof, // C offset
0, // D offset
};
@@ -819,30 +905,43 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
}
if (dw)
{
if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
{
mtl_buffer_t* tensors[3] = {
mpgetbuffer((ccv_nnc_tensor_t*)a), // A
(mtl_buffer_t*)scratch, // B
NULL,
};
size_t tensor_offsets[2] = {
a->dataof, // A offset
0, // B offset
};
ccv_nnc_mfa_encode_depalettize(context, a_depalettize_params, command_batch, tensors, tensor_offsets);
}
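// Likewise, a_data/a_dataof now resolve to the depalettized copy of a
// whenever a was palettized.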
if (is_transpose_w)
{
mtl_buffer_t* tensors[4] = {
mpgetbuffer((ccv_nnc_tensor_t*)g), // A
-mpgetbuffer((ccv_nnc_tensor_t*)a), // B
a_data, // B
mpgetbuffer((ccv_nnc_tensor_t*)dw), // C
NULL,
};
size_t tensor_offsets[4] = {
g->dataof, // A offset
-a->dataof, // B offset
a_dataof, // B offset
dw->dataof, // C offset
0, // D offset
};
ccv_nnc_mfa_encode_gemm(context, dw_params, command_batch, tensors, tensor_offsets);
} else {
mtl_buffer_t* tensors[4] = {
-mpgetbuffer((ccv_nnc_tensor_t*)a), // A
a_data, // A
mpgetbuffer((ccv_nnc_tensor_t*)g), // B
mpgetbuffer((ccv_nnc_tensor_t*)dw), // C
NULL,
};
size_t tensor_offsets[4] = {
-a->dataof, // A offset
a_dataof, // A offset
g->dataof, // B offset
dw->dataof, // C offset
0, // D offset
@@ -852,7 +951,90 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
}
ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch);
} else {
-MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
mtl_buffer_t* a_data = 0;
size_t a_dataof = 0;
if (a && dw)
{
a_data = mpgetbuffer((ccv_nnc_tensor_t*)a);
a_dataof = (size_t)mpgetoffset((ccv_nnc_tensor_t*)a);
}
mtl_buffer_t* w_data = 0;
size_t w_dataof = 0;
if (w && h)
{
w_data = mpgetbuffer((ccv_nnc_tensor_t*)w);
w_dataof = (size_t)mpgetoffset((ccv_nnc_tensor_t*)w);
}
MPSCommandBuffer* command_buffer;
if ((a && dw && CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX) || (w && h && CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX))
{
mtl_buffer_t* scratch = 0;
if (a_data_size + w_data_size > 0)
scratch = ccv_nnc_mfa_request_scratch(context, a_data_size + w_data_size);
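// Same scratch layout as the MFA branch above: a at offset 0, w at a_data_size.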
ccv_nnc_mfa_depalettize_params_t a_depalettize_params;
if (a && dw && CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
{
ccv_nnc_tensor_param_t a_params = a->info;
const size_t count = ccv_nnc_tensor_count(a_params);
const int qbits = (a_params.datatype & 0xf00) >> 8;
const int number_in_blocks = a_params.reserved;
a_depalettize_params = (ccv_nnc_mfa_depalettize_params_t){
.data_type = mtl_data_type,
.qbits = (uint32_t)qbits,
.number_in_blocks = (uint32_t)number_in_blocks,
.length = (uint64_t)count,
};
ccv_nnc_mfa_prepare_depalettize(context, a_depalettize_params);
a_data = scratch;
a_dataof = 0;
}
ccv_nnc_mfa_depalettize_params_t w_depalettize_params;
if (w && h && CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
{
ccv_nnc_tensor_param_t w_params = w->info;
const size_t count = ccv_nnc_tensor_count(w_params);
const int qbits = (w_params.datatype & 0xf00) >> 8;
const int number_in_blocks = w_params.reserved;
w_depalettize_params = (ccv_nnc_mfa_depalettize_params_t){
.data_type = mtl_data_type,
.qbits = (uint32_t)qbits,
.number_in_blocks = (uint32_t)number_in_blocks,
.length = (uint64_t)count,
};
ccv_nnc_mfa_prepare_depalettize(context, w_depalettize_params);
w_data = scratch;
w_dataof = a_data_size;
}
mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
if (a && dw && CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
{
mtl_buffer_t* tensors[3] = {
mpgetbuffer((ccv_nnc_tensor_t*)a), // A
(mtl_buffer_t*)scratch, // B
NULL,
};
size_t tensor_offsets[2] = {
a->dataof, // A offset
0, // B offset
};
ccv_nnc_mfa_encode_depalettize(context, a_depalettize_params, command_batch, tensors, tensor_offsets);
}
if (w && h && CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
{
mtl_buffer_t* tensors[3] = {
mpgetbuffer((ccv_nnc_tensor_t*)w), // A
(mtl_buffer_t*)scratch, // B
NULL,
};
size_t tensor_offsets[2] = {
w->dataof, // A offset
a_data_size, // B offset
};
ccv_nnc_mfa_encode_depalettize(context, w_depalettize_params, command_batch, tensors, tensor_offsets);
}
command_buffer = ccv_nnc_stream_context_finish_command_batch_encoding_and_return_mps_command_buffer(stream_context, command_batch);
} else // Otherwise, incur the ~10-50 microsecond latency of MPS.
command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
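// When either tensor was palettized, the depalettize dispatches were already
// encoded into a command batch that is handed off as the MPS command buffer,
// so the MPSGraph work below queues behind the dequantization on the stream.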

if (h) {
assert(w); // when calculate h, w must exist
@@ -906,7 +1088,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
[resultTensors addObject:mps_h];
});
MPSGraphTensorData* data_g = ccv_nnc_mps_graph_tensor_data(g, g->info.dim, g->stride);
-MPSGraphTensorData* data_w = ccv_nnc_mps_graph_tensor_data(w, w->info.dim, w->stride);
MPSGraphTensorData* data_w = ccv_nnc_mps_graph_tensor_data_with_buffer(w, w->info.dim, w->stride, w_data, w_dataof);
MPSGraphTensorData* data[] = {data_g, data_w};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &h, (int*[]){ h->info.dim }, (int*[]){ h->stride }, 1);
}
@@ -964,7 +1146,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

});
MPSGraphTensorData* data_g = ccv_nnc_mps_graph_tensor_data(g, g->info.dim, g->stride);
-MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data_with_buffer(a, a->info.dim, a->stride, a_data, a_dataof);
MPSGraphTensorData* data[] = {data_g, data_a};
ccv_nnc_mps_graph_executable_result(executable_dw, command_buffer, @[data[dw_indices[0]], data[dw_indices[1]]], &dw, (int*[]){ dw->info.dim }, (int*[]){ dw->stride }, 1);
}
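The request-scratch / prepare / encode / repoint sequence above appears four times in the GEMM backward pass and once more in the convolution backward pass below. A sketch of how the per-tensor steps could be factored into one helper follows; the function name and signature are hypothetical, not part of this commit:

// Hypothetical helper (a sketch, not in the commit): depalettize one CCV_QX
// tensor into `scratch` at `dst_offset`, then repoint data/dataof at the copy.
static void depalettize_into_scratch(ccv_nnc_mfa_context_t* const context, mtl_command_batch_t* const command_batch, ccv_nnc_tensor_t* const tensor, const uint32_t mtl_data_type, mtl_buffer_t* const scratch, const size_t dst_offset, mtl_buffer_t** const data, size_t* const dataof)
{
	const ccv_nnc_tensor_param_t params = tensor->info;
	const ccv_nnc_mfa_depalettize_params_t depalettize_params = {
		.data_type = mtl_data_type,
		.qbits = (uint32_t)((params.datatype & 0xf00) >> 8),
		.number_in_blocks = (uint32_t)params.reserved,
		.length = (uint64_t)ccv_nnc_tensor_count(params),
	};
	ccv_nnc_mfa_prepare_depalettize(context, depalettize_params);
	mtl_buffer_t* tensors[3] = {
		mpgetbuffer(tensor), // A, the palettized source
		scratch, // B, the full-precision destination
		NULL,
	};
	size_t tensor_offsets[2] = {
		tensor->dataof, // A offset
		dst_offset, // B offset
	};
	ccv_nnc_mfa_encode_depalettize(context, depalettize_params, command_batch, tensors, tensor_offsets);
	*data = scratch;
	*dataof = dst_offset;
}
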
48 changes: 45 additions & 3 deletions lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
@@ -421,9 +421,46 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
ccv_nnc_tensor_view_t* db = output_size > 2 ? (ccv_nnc_tensor_view_t*)outputs[2] : 0;

@autoreleasepool {
-MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
MPSCommandBuffer* command_buffer = 0;
ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();

if (h) {
mtl_buffer_t* w_data = mpgetbuffer((ccv_nnc_tensor_t*)w);
size_t w_dataof = (size_t)mpgetoffset((ccv_nnc_tensor_t*)w);
if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
{
ccv_nnc_tensor_param_t w_params = w->info;
const int palette_datatype = (w_params.datatype & 0xff) << 12;
ccv_nnc_tensor_param_t depalettize_w_params = w_params;
depalettize_w_params.datatype = palette_datatype;
depalettize_w_params.reserved = 0;
size_t w_data_size = ccv_nnc_tensor_data_size(depalettize_w_params);
const size_t count = ccv_nnc_tensor_count(w_params);
const int qbits = (w_params.datatype & 0xf00) >> 8;
const int number_in_blocks = w_params.reserved;
ccv_nnc_mfa_depalettize_params_t w_depalettize_params = {
.data_type = palette_datatype == CCV_16F ? 16 : 3, // MTLDataTypeHalf (16) : MTLDataTypeFloat (3)
.qbits = (uint32_t)qbits,
.number_in_blocks = (uint32_t)number_in_blocks,
.length = (uint64_t)count,
};
ccv_nnc_mfa_prepare_depalettize(context, w_depalettize_params);
w_data = ccv_nnc_mfa_request_scratch(context, w_data_size);
w_dataof = 0;
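// The depalettized weights get their own scratch allocation here, so the
// MPSGraph below reads them starting at offset 0.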
mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
mtl_buffer_t* tensors[3] = {
mpgetbuffer((ccv_nnc_tensor_t*)w), // A
(mtl_buffer_t*)w_data, // B
NULL,
};
size_t tensor_offsets[2] = {
w->dataof, // A offset
0, // B offset
};
ccv_nnc_mfa_encode_depalettize(context, w_depalettize_params, command_batch, tensors, tensor_offsets);
command_buffer = ccv_nnc_stream_context_finish_command_batch_encoding_and_return_mps_command_buffer(stream_context, command_batch);
} else
command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
// [output gradient]
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
int indices[2];
@@ -456,12 +493,14 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
[resultTensors addObject:mps_h];
});
MPSGraphTensorData* data_g = ccv_nnc_mps_graph_tensor_data(g, g->info.dim, g->stride);
-MPSGraphTensorData* data_w = ccv_nnc_mps_graph_tensor_data(w, w->info.dim, w->stride);
MPSGraphTensorData* data_w = ccv_nnc_mps_graph_tensor_data_with_buffer(w, w->info.dim, w->stride, w_data, w_dataof);
MPSGraphTensorData* data[] = {data_g, data_w};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &h, (int*[]){ h->info.dim }, (int*[]){ h->stride }, 1);
}
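// h may be absent, so this branch (and the db branch below) lazily creates
// the plain MPS command buffer if the path above did not already make one.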

if (dw) {
if (!command_buffer)
command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
// [weight updates]
ccv_nnc_mps_graph_key_t dw_key = ccv_nnc_mps_graph_key_new(cmd, 1, hint, flags, inputs, input_size, outputs, output_size);
int dw_indices[2];
@@ -501,6 +540,8 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
}

if (db) {
if (!command_buffer)
command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
// [bias updates]
ccv_nnc_mps_graph_key_t db_key = ccv_nnc_mps_graph_key_new(cmd, 2, hint, flags, inputs, input_size, outputs, output_size);
int db_indices[1];
@@ -526,7 +567,8 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
ccv_nnc_mps_graph_executable_result(executable_db, command_buffer, @[data_g], &db, (int*[]){ db->info.dim }, (int*[]){ db->stride }, 1);
}

-ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
if (command_buffer)
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
return CCV_NNC_EXEC_SUCCESS;
}
