diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index ed32b93295f7f..ee6298d52461a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -8,6 +8,7 @@ #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/slice_helper.h" +#include "core/providers/shared/utils/utils.h" #if defined(__APPLE__) #include "core/providers/coreml/builders/model_builder.h" @@ -32,6 +33,7 @@ class SliceOpBuilder : public BaseOpBuilder { return 10; } + bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params, const logging::Logger& logger) const override; }; @@ -145,8 +147,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const if (step < 0 && end == -1) { // Special case - stepping backwards up to and including the first index in the dimension. - // In ONNX Slice, we can use end = -1 to represent this. In CoreML, endids = -1 doesn't work like that so we can - // use endmasks to specify the rest of the dimension instead. + // In ONNX Slice, we use end <= -(rank + 1) to represent this. In CoreML, setting endids like that doesn't work, + // so use endmasks to specify the rest of the dimension instead. slice_static->add_endids(-1); // ignored slice_static->add_endmasks(true); } else { @@ -164,6 +166,22 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const #endif // defined(__APPLE__) // Operator support related +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { + int32_t input_type; + if (!GetType(*node.InputDefs()[0], input_type, logger)) + return false; + + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + bool SliceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params, const logging::Logger& logger) const { const auto input_defs = node.InputDefs(); diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 108bb488ec960..1da9c9df299c9 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -261,15 +261,25 @@ TEST(SliceTest, Slice3D) { 332.0f, 333.0f}); } -TEST(SliceTest, Slice1D_Int) { - RunSliceTest({6}, - {0L, 1L, 2L, 3L, 4L, 5L}, - {2}, - {4}, - {0}, - {}, - {2}, - {2L, 3L}); +template +static void TestSlice1DIntData() { + static_assert(std::is_integral_v); + RunSliceTest({6}, + {0, 1, 2, 3, 4, 5}, + {2}, + {4}, + {0}, + {}, + {2}, + {2, 3}); +} + +TEST(SliceTest, Slice1D_Int32) { + TestSlice1DIntData(); +} + +TEST(SliceTest, Slice1D_Int64) { + TestSlice1DIntData(); } TEST(SliceTest, Slice1D_String) {