diff --git a/lib/nnc/cmd/upsample/mps/ccv_nnc_upsample_mps.m b/lib/nnc/cmd/upsample/mps/ccv_nnc_upsample_mps.m index ad69617ee..02e4ea454 100644 --- a/lib/nnc/cmd/upsample/mps/ccv_nnc_upsample_mps.m +++ b/lib/nnc/cmd/upsample/mps/ccv_nnc_upsample_mps.m @@ -216,6 +216,20 @@ static int _ccv_nnc_upsample_bilinear_back(const ccv_nnc_cmd_t cmd, const ccv_nn // for unknown reason, MPS handling NHWC as NHCW... // explicitly transpose input and output for NHWC [inputSize exchangeObjectAtIndex:2 withObjectAtIndex:3]; + [inputSize exchangeObjectAtIndex:1 withObjectAtIndex:2]; + int t = adim[3]; + adim[3] = adim[2]; + adim[2] = adim[1]; + adim[1] = t; + t = astride[3]; + astride[3] = astride[2]; + astride[2] = astride[1]; + astride[1] = t; + ccv_nnc_tensor_view_t at = ccv_nnc_get_tensor_view(a); + at.contiguous = 0; + at.off = 0; + at.type |= CCV_TENSOR_VIEW; + memcpy(at.stride, astride, sizeof(at.stride)); @autoreleasepool { MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context); ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size); @@ -226,8 +240,9 @@ static int _ccv_nnc_upsample_bilinear_back(const ccv_nnc_cmd_t cmd, const ccv_nn [inputTensors addObject:mps_input_b]; MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, bdim_r, bstride_r); [inputShapedTypes addObject:mps_b_shape]; - // NHWC to NHCW + // NHWC to NHCW. This probably is not needed on iOS 17. mps_b = [graph transposeTensor:mps_b dimension:-1 withDimension:-2 name:nil]; + mps_b = [graph transposeTensor:mps_b dimension:-2 withDimension:-3 name:nil]; MPSGraphTensor* inputSizeTensor = [graph constantWithScalar:0 shape:inputSize dataType:ccv_nnc_mps_datatype(b->info.datatype)]; MPSGraphTensor* mps_a = [graph resizeWithGradientTensor:mps_b @@ -235,14 +250,13 @@ static int _ccv_nnc_upsample_bilinear_back(const ccv_nnc_cmd_t cmd, const ccv_nn mode:MPSGraphResizeBilinear centerResult:YES alignCorners:NO - layout:MPSGraphTensorNamedDataLayoutNHWC + layout:MPSGraphTensorNamedDataLayoutNCHW name:nil]; - // NHCW to NHWC - mps_a = [graph transposeTensor:mps_a dimension:-1 withDimension:-2 name:nil]; + // No need for NHCW to NHWC [resultTensors addObject:mps_a]; }); MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, bdim_r, bstride_r); - ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_b], &a , (int*[]){ adim_r }, (int*[]){ astride_r }, 1); + ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_b], (ccv_nnc_tensor_view_t*[]){ &at }, (int*[]){ adim_r }, (int*[]){ astride_r }, 1); ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer); @@ -315,7 +329,21 @@ static int _ccv_nnc_upsample_nearest_back(const ccv_nnc_cmd_t cmd, const ccv_nnc assert(inputSize.count == 4); // for unknown reason, MPS handling NHWC as NHCW... // explicitly transpose input and output for NHWC - [inputSize exchangeObjectAtIndex:2 withObjectAtIndex:3]; + [inputSize exchangeObjectAtIndex:2 withObjectAtIndex:3]; + [inputSize exchangeObjectAtIndex:1 withObjectAtIndex:2]; + int t = adim[3]; + adim[3] = adim[2]; + adim[2] = adim[1]; + adim[1] = t; + t = astride[3]; + astride[3] = astride[2]; + astride[2] = astride[1]; + astride[1] = t; + ccv_nnc_tensor_view_t at = ccv_nnc_get_tensor_view(a); + at.contiguous = 0; + at.off = 0; + at.type |= CCV_TENSOR_VIEW; + memcpy(at.stride, astride, sizeof(at.stride)); @autoreleasepool { MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context); @@ -329,23 +357,23 @@ static int _ccv_nnc_upsample_nearest_back(const ccv_nnc_cmd_t cmd, const ccv_nnc [inputShapedTypes addObject:mps_b_shape]; MPSGraphTensor* inputSizeTensor = [graph constantWithScalar:0 shape:inputSize dataType:ccv_nnc_mps_datatype(b->info.datatype)]; - // NHWC to NHCW + // NHWC to NHCW. This probably is not needed on iOS 17. mps_b = [graph transposeTensor:mps_b dimension:-1 withDimension:-2 name:nil]; + mps_b = [graph transposeTensor:mps_b dimension:-2 withDimension:-3 name:nil]; MPSGraphTensor* mps_a = [graph resizeWithGradientTensor:mps_b input:inputSizeTensor mode:MPSGraphResizeNearest centerResult:YES alignCorners:NO - layout:MPSGraphTensorNamedDataLayoutNHWC + layout:MPSGraphTensorNamedDataLayoutNCHW name:nil]; - // NHCW to NHWC - mps_a = [graph transposeTensor:mps_a dimension:-1 withDimension:-2 name:nil]; + // No need for NHCW to NHWC [resultTensors addObject:mps_a]; }); MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, bdim, bstride); - ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_b], &a, (int*[]){ adim }, (int*[]){ astride }, 1); + ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_b], (ccv_nnc_tensor_view_t*[]){ &at }, (int*[]){ adim }, (int*[]){ astride }, 1); ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer); } }