Skip to content

Commit

Permalink
Fix upsample back on macOS 14.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Sep 20, 2023
1 parent d3e2463 commit 1ae66bb
Showing 1 changed file with 39 additions and 11 deletions.
50 changes: 39 additions & 11 deletions lib/nnc/cmd/upsample/mps/ccv_nnc_upsample_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,20 @@ static int _ccv_nnc_upsample_bilinear_back(const ccv_nnc_cmd_t cmd, const ccv_nn
// For some unknown reason, MPS handles NHWC as NHCW,
// so explicitly transpose the input and output for NHWC.
[inputSize exchangeObjectAtIndex:2 withObjectAtIndex:3];
[inputSize exchangeObjectAtIndex:1 withObjectAtIndex:2];
int t = adim[3];
adim[3] = adim[2];
adim[2] = adim[1];
adim[1] = t;
t = astride[3];
astride[3] = astride[2];
astride[2] = astride[1];
astride[1] = t;
ccv_nnc_tensor_view_t at = ccv_nnc_get_tensor_view(a);
at.contiguous = 0;
at.off = 0;
at.type |= CCV_TENSOR_VIEW;
memcpy(at.stride, astride, sizeof(at.stride));
@autoreleasepool {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
Expand All @@ -226,23 +240,23 @@ static int _ccv_nnc_upsample_bilinear_back(const ccv_nnc_cmd_t cmd, const ccv_nn
[inputTensors addObject:mps_input_b];
MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, bdim_r, bstride_r);
[inputShapedTypes addObject:mps_b_shape];
// NHWC to NHCW
// NHWC to NCHW (two transposes: NHWC -> NHCW -> NCHW). This is probably not needed on iOS 17.
mps_b = [graph transposeTensor:mps_b dimension:-1 withDimension:-2 name:nil];
mps_b = [graph transposeTensor:mps_b dimension:-2 withDimension:-3 name:nil];
MPSGraphTensor* inputSizeTensor = [graph constantWithScalar:0 shape:inputSize dataType:ccv_nnc_mps_datatype(b->info.datatype)];

MPSGraphTensor* mps_a = [graph resizeWithGradientTensor:mps_b
input:inputSizeTensor
mode:MPSGraphResizeBilinear
centerResult:YES
alignCorners:NO
layout:MPSGraphTensorNamedDataLayoutNHWC
layout:MPSGraphTensorNamedDataLayoutNCHW
name:nil];
// NHCW to NHWC
mps_a = [graph transposeTensor:mps_a dimension:-1 withDimension:-2 name:nil];
// No need to transpose the result from NCHW back to NHWC here.
[resultTensors addObject:mps_a];
});
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, bdim_r, bstride_r);
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_b], &a , (int*[]){ adim_r }, (int*[]){ astride_r }, 1);
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_b], (ccv_nnc_tensor_view_t*[]){ &at }, (int*[]){ adim_r }, (int*[]){ astride_r }, 1);

ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);

Expand Down Expand Up @@ -315,7 +329,21 @@ static int _ccv_nnc_upsample_nearest_back(const ccv_nnc_cmd_t cmd, const ccv_nnc
assert(inputSize.count == 4);
// For some unknown reason, MPS handles NHWC as NHCW,
// so explicitly transpose the input and output for NHWC.
[inputSize exchangeObjectAtIndex:2 withObjectAtIndex:3];
[inputSize exchangeObjectAtIndex:2 withObjectAtIndex:3];
[inputSize exchangeObjectAtIndex:1 withObjectAtIndex:2];
int t = adim[3];
adim[3] = adim[2];
adim[2] = adim[1];
adim[1] = t;
t = astride[3];
astride[3] = astride[2];
astride[2] = astride[1];
astride[1] = t;
ccv_nnc_tensor_view_t at = ccv_nnc_get_tensor_view(a);
at.contiguous = 0;
at.off = 0;
at.type |= CCV_TENSOR_VIEW;
memcpy(at.stride, astride, sizeof(at.stride));

@autoreleasepool {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
Expand All @@ -329,23 +357,23 @@ static int _ccv_nnc_upsample_nearest_back(const ccv_nnc_cmd_t cmd, const ccv_nnc
[inputShapedTypes addObject:mps_b_shape];

MPSGraphTensor* inputSizeTensor = [graph constantWithScalar:0 shape:inputSize dataType:ccv_nnc_mps_datatype(b->info.datatype)];
// NHWC to NHCW
// NHWC to NCHW (two transposes: NHWC -> NHCW -> NCHW). This is probably not needed on iOS 17.
mps_b = [graph transposeTensor:mps_b dimension:-1 withDimension:-2 name:nil];
mps_b = [graph transposeTensor:mps_b dimension:-2 withDimension:-3 name:nil];

MPSGraphTensor* mps_a = [graph resizeWithGradientTensor:mps_b
input:inputSizeTensor
mode:MPSGraphResizeNearest
centerResult:YES
alignCorners:NO
layout:MPSGraphTensorNamedDataLayoutNHWC
layout:MPSGraphTensorNamedDataLayoutNCHW
name:nil];
// NHCW to NHWC
mps_a = [graph transposeTensor:mps_a dimension:-1 withDimension:-2 name:nil];
// No need to transpose the result from NCHW back to NHWC here.

[resultTensors addObject:mps_a];
});
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, bdim, bstride);
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_b], &a, (int*[]){ adim }, (int*[]){ astride }, 1);
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_b], (ccv_nnc_tensor_view_t*[]){ &at }, (int*[]){ adim }, (int*[]){ astride }, 1);
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
}
Expand Down

0 comments on commit 1ae66bb

Please sign in to comment.