Skip to content

Commit

Permalink
Enable int64 slice data input.
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Aug 17, 2023
1 parent 184054b commit cf93e40
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/providers/cpu/tensor/slice_helper.h"
#include "core/providers/shared/utils/utils.h"

#if defined(__APPLE__)
#include "core/providers/coreml/builders/model_builder.h"
Expand All @@ -32,6 +33,7 @@ class SliceOpBuilder : public BaseOpBuilder {
return 10;
}

bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override;
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params,
const logging::Logger& logger) const override;
};
Expand Down Expand Up @@ -145,8 +147,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const

if (step < 0 && end == -1) {
// Special case - stepping backwards up to and including the first index in the dimension.
// In ONNX Slice, we can use end = -1 to represent this. In CoreML, endids = -1 doesn't work like that so we can
// use endmasks to specify the rest of the dimension instead.
// In ONNX Slice, we use end <= -(rank + 1) to represent this. In CoreML, setting endids like that doesn't work,
// so use endmasks to specify the rest of the dimension instead.
slice_static->add_endids(-1); // ignored
slice_static->add_endmasks(true);
} else {
Expand All @@ -164,6 +166,22 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
#endif // defined(__APPLE__)

// Operator support related
bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const {
int32_t input_type;
if (!GetType(*node.InputDefs()[0], input_type, logger))
return false;

if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
LOGS(logger, VERBOSE) << "[" << node.OpType()
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}

return true;
}

bool SliceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params,
const logging::Logger& logger) const {
const auto input_defs = node.InputDefs();
Expand Down
28 changes: 19 additions & 9 deletions onnxruntime/test/providers/cpu/tensor/slice_op.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,15 +261,25 @@ TEST(SliceTest, Slice3D) {
332.0f, 333.0f});
}

TEST(SliceTest, Slice1D_Int) {
RunSliceTest<int32_t>({6},
{0L, 1L, 2L, 3L, 4L, 5L},
{2},
{4},
{0},
{},
{2},
{2L, 3L});
template <typename TInt>
static void TestSlice1DIntData() {
static_assert(std::is_integral_v<TInt>);
RunSliceTest<TInt>({6},
{0, 1, 2, 3, 4, 5},
{2},
{4},
{0},
{},
{2},
{2, 3});
}

TEST(SliceTest, Slice1D_Int32) {
TestSlice1DIntData<int32_t>();
}

TEST(SliceTest, Slice1D_Int64) {
TestSlice1DIntData<int64_t>();
}

TEST(SliceTest, Slice1D_String) {
Expand Down

0 comments on commit cf93e40

Please sign in to comment.