Skip to content

Commit

Permalink
Don't provide inputs that are not CoreML model inputs to CoreML
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Aug 17, 2023
1 parent 5af548a commit 2efa90b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
20 changes: 14 additions & 6 deletions onnxruntime/core/providers/coreml/coreml_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
onnx_input_names[i] = input_defs[i]->Name();
}
coreml_model->SetInputs(std::move(onnx_input_names));
coreml_model->SetOnnxInputs(std::move(onnx_input_names));
}

{
Expand All @@ -113,7 +113,7 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
onnx_output_names[i] = output_defs[i]->Name();
}
coreml_model->SetOutputs(std::move(onnx_output_names));
coreml_model->SetOnnxOutputs(std::move(onnx_output_names));
}

coreml_models_.emplace(fused_node.Name(), std::move(coreml_model));
Expand All @@ -136,8 +136,8 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
const size_t num_outputs = ctx.GetOutputCount();

coreml::Model* model = reinterpret_cast<coreml::Model*>(state);
const auto& model_inputs = model->GetInputs();
const auto& model_outputs = model->GetOutputs();
const auto& model_inputs = model->GetOnnxInputs();
const auto& model_outputs = model->GetOnnxOutputs();

ORT_RETURN_IF_NOT(model_inputs.size() <= num_inputs, "Inconsistent input sizes");
ORT_RETURN_IF_NOT(model_outputs.size() == num_outputs, "Inconsistent output sizes");
Expand All @@ -146,14 +146,22 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
inputs.reserve(model_inputs.size());
for (size_t i = 0; i < model_inputs.size(); i++) {
const auto& input_name = model_inputs[i];
const auto* input_info = model->TryGetInputOutputInfo(input_name);
if (input_info == nullptr) {
// The CoreML model may not have an actual input that corresponds to this one.
// E.g., when the input is an initializer that already got copied to the CoreML model.
// If there's no CoreML model input, we don't need to provide this input to CoreML.
continue;
}

auto input_tensor = ctx.GetInput(i);
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
auto shape = tensor_info.GetShape();

// Disallow inputs with dynamic shape which actually have zero elements.
// CoreML doesn't consistently handle this well (e.g., there may be runtime errors).
if (const auto* model_input_info = model->TryGetInputOutputInfo(input_name); model_input_info != nullptr) {
const auto& inferred_shape = model_input_info->shape;
{
const auto& inferred_shape = input_info->shape;
ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape),
"Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape),
") but the runtime shape (", coreml::Shape2String(shape),
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/core/providers/coreml/model/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ class Model {
OrtMutex& GetMutex() { return mutex_; }

// Input and output names in the onnx model's order
const std::vector<std::string>& GetInputs() const { return inputs_; }
void SetInputs(std::vector<std::string>&& inputs) { inputs_ = std::move(inputs); }
const std::vector<std::string>& GetOnnxInputs() const { return onnx_inputs_; }
void SetOnnxInputs(std::vector<std::string>&& inputs) { onnx_inputs_ = std::move(inputs); }

const std::vector<std::string>& GetOutputs() const { return outputs_; }
void SetOutputs(std::vector<std::string>&& outputs) { outputs_ = std::move(outputs); }
const std::vector<std::string>& GetOnnxOutputs() const { return onnx_outputs_; }
void SetOnnxOutputs(std::vector<std::string>&& outputs) { onnx_outputs_ = std::move(outputs); }

const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const;
const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const;
Expand All @@ -65,8 +65,8 @@ class Model {
std::unordered_set<std::string> scalar_outputs_;
std::unordered_set<std::string> int64_outputs_;

std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
std::vector<std::string> onnx_inputs_;
std::vector<std::string> onnx_outputs_;

std::unordered_map<std::string, OnnxTensorInfo> input_output_info_;

Expand Down

0 comments on commit 2efa90b

Please sign in to comment.