Skip to content

Commit

Permalink
Less compiled kernels for add op.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Sep 25, 2023
1 parent 0a5a1d1 commit 3fd457b
Showing 1 changed file with 112 additions and 29 deletions.
141 changes: 112 additions & 29 deletions lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static int _ccv_nnc_add_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,
else
ccv_nnc_mps_export_data(data_a, command_buffer, c, c->info.dim, c->stride);
[graph release];
} else {
} else if (p == 0.5 || p == 2 || p == 1.0 / 3 || p == 3 || p == 1.0 / 10 || p == 10) { // Only create specialized kernels for special p values.
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
int indices[1];
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
Expand All @@ -51,6 +51,29 @@ static int _ccv_nnc_add_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,
});
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_a], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1);
} else {
ccv_nnc_cmd_t cmd_without_p = cmd;
cmd_without_p.info.blas.a[0] = 0;
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd_without_p, 1, hint, flags, inputs, input_size, outputs, output_size);
int indices[2];
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
[inputTensors addObject:mps_input_a];
MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
[inputShapedTypes addObject:mps_a_shape];
MPSGraphTensor* mps_p = [graph placeholderWithShape:@[@1] dataType:ccv_nnc_mps_datatype(a->info.datatype) name:nil];
[inputTensors addObject:mps_p];
MPSGraphShapedType* mps_p_shape = [[MPSGraphShapedType alloc] initWithShape:@[@1] dataType:ccv_nnc_mps_datatype(a->info.datatype)];
[inputShapedTypes addObject:mps_p_shape];
[mps_p_shape release];
MPSGraphTensor* mps_c = [graph multiplicationWithPrimaryTensor:mps_a secondaryTensor:mps_p name:nil];
[resultTensors addObject:mps_c];
});
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
MPSGraphTensorData* data_p = ccv_nnc_mps_graph_constant_data(p, a->info.datatype);
MPSGraphTensorData* data[] = {data_a, data_p};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1);
}
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
Expand All @@ -59,36 +82,96 @@ static int _ccv_nnc_add_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,
const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)inputs[1];
@autoreleasepool {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
int indices[2];
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
[inputTensors addObject:mps_input_a];
MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
[inputShapedTypes addObject:mps_a_shape];
if (p != 1)
{
MPSGraphTensor* mps_p = [graph constantWithScalar:p dataType:ccv_nnc_mps_datatype(a->info.datatype)];
if (p == 1 && q == 1)
{
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
int indices[2];
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
[inputTensors addObject:mps_input_a];
MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
[inputShapedTypes addObject:mps_a_shape];
MPSGraphTensor* mps_input_b;
MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
[inputTensors addObject:mps_input_b];
MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
[inputShapedTypes addObject:mps_b_shape];
MPSGraphTensor* mps_c = [graph additionWithPrimaryTensor:mps_a secondaryTensor:mps_b name:nil];
[resultTensors addObject:mps_c];
});
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
MPSGraphTensorData* data[] = {data_a, data_b};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1);
} else if ((p == 1 || p == 0.5 || p == 2 || p == 1.0 / 3 || p == 3 || p == 1.0 / 10 || p == 10) && (q == 1 || q == 0.5 || q == 2 || q == 1.0 / 3 || q == 3 || q == 1.0 / 10 || q == 10)) { // Only create specialized kernels for special p / q values.
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
int indices[2];
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
[inputTensors addObject:mps_input_a];
MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
[inputShapedTypes addObject:mps_a_shape];
if (p != 1)
{
MPSGraphTensor* mps_p = [graph constantWithScalar:p dataType:ccv_nnc_mps_datatype(a->info.datatype)];
mps_a = [graph multiplicationWithPrimaryTensor:mps_a secondaryTensor:mps_p name:nil];
}
MPSGraphTensor* mps_input_b;
MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
[inputTensors addObject:mps_input_b];
MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
[inputShapedTypes addObject:mps_b_shape];
if (q != 1)
{
MPSGraphTensor* mps_q = [graph constantWithScalar:q dataType:ccv_nnc_mps_datatype(b->info.datatype)];
mps_b = [graph multiplicationWithPrimaryTensor:mps_b secondaryTensor:mps_q name:nil];
}
MPSGraphTensor* mps_c = [graph additionWithPrimaryTensor:mps_a secondaryTensor:mps_b name:nil];
[resultTensors addObject:mps_c];
});
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
MPSGraphTensorData* data[] = {data_a, data_b};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1);
} else {
ccv_nnc_cmd_t cmd_without_p_q = cmd;
cmd_without_p_q.info.blas.a[0] = 0;
cmd_without_p_q.info.blas.a[1] = 0;
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd_without_p_q, 1, hint, flags, inputs, input_size, outputs, output_size);
int indices[4];
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
[inputTensors addObject:mps_input_a];
MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
[inputShapedTypes addObject:mps_a_shape];
MPSGraphTensor* mps_p = [graph placeholderWithShape:@[@1] dataType:ccv_nnc_mps_datatype(a->info.datatype) name:nil];
[inputTensors addObject:mps_p];
MPSGraphShapedType* mps_p_shape = [[MPSGraphShapedType alloc] initWithShape:@[@1] dataType:ccv_nnc_mps_datatype(a->info.datatype)];
[inputShapedTypes addObject:mps_p_shape];
mps_a = [graph multiplicationWithPrimaryTensor:mps_a secondaryTensor:mps_p name:nil];
}
MPSGraphTensor* mps_input_b;
MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
[inputTensors addObject:mps_input_b];
MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
[inputShapedTypes addObject:mps_b_shape];
if (q != 1)
{
MPSGraphTensor* mps_q = [graph constantWithScalar:q dataType:ccv_nnc_mps_datatype(b->info.datatype)];
MPSGraphTensor* mps_input_b;
MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
[inputTensors addObject:mps_input_b];
MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
[inputShapedTypes addObject:mps_b_shape];
MPSGraphTensor* mps_q = [graph placeholderWithShape:@[@1] dataType:ccv_nnc_mps_datatype(b->info.datatype) name:nil];
[inputTensors addObject:mps_q];
MPSGraphShapedType* mps_q_shape = [[MPSGraphShapedType alloc] initWithShape:@[@1] dataType:ccv_nnc_mps_datatype(b->info.datatype)];
[inputShapedTypes addObject:mps_q_shape];
mps_b = [graph multiplicationWithPrimaryTensor:mps_b secondaryTensor:mps_q name:nil];
}
MPSGraphTensor* mps_c = [graph additionWithPrimaryTensor:mps_a secondaryTensor:mps_b name:nil];
[resultTensors addObject:mps_c];
});
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
MPSGraphTensorData* data[] = {data_a, data_b};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1);
MPSGraphTensor* mps_c = [graph additionWithPrimaryTensor:mps_a secondaryTensor:mps_b name:nil];
[resultTensors addObject:mps_c];
});
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
MPSGraphTensorData* data_p = ccv_nnc_mps_graph_constant_data(p, a->info.datatype);
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
MPSGraphTensorData* data_q = ccv_nnc_mps_graph_constant_data(q, a->info.datatype);
MPSGraphTensorData* data[] = {data_a, data_p, data_b, data_q};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]], data[indices[2]], data[indices[3]]], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1);
}
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
return CCV_NNC_EXEC_SUCCESS;
Expand Down

0 comments on commit 3fd457b

Please sign in to comment.