
[ONNX] Systematically expand ONNX functions before conversion to Torch to avoid needing bespoke conversions #3384

Closed
andfau-amd opened this issue May 23, 2024 · 12 comments · Fixed by #3409
@andfau-amd
Contributor

andfau-amd commented May 23, 2024

Currently there are a large number of operators in ONNX that aren't supported by TorchOnnxToTorch. As a beginner, I was tasked with adding a conversion for one of those, namely MeanVarianceNormalization, in nod-ai/SHARK-ModelDev#697. But I noticed that this operator is actually a function, and operators that are functions are defined by a series of other ONNX operators. In this case, all those other operators are ones we already have conversions for. So I wondered why we don't expand these operator functions before converting, to avoid needing bespoke conversions for all of them.
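
To make the "this operator is actually a function" point concrete, here is a minimal sketch using the public onnx Python API (attribute names as I understand that API) that asks whether an operator carries a function body:

# Minimal sketch: check whether an ONNX operator is defined as a function
# of simpler ONNX operators.
import onnx.defs

# MeanVarianceNormalization at opset 18; get_schema takes the operator name
# and the highest opset version to consider.
schema = onnx.defs.get_schema("MeanVarianceNormalization", 18)

if schema.has_function:
    body = schema.function_body  # a FunctionProto
    print("Expands into:", [n.op_type for n in body.node])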

I discussed this a bit with some colleagues (@ScottTodd, @rsuderman) and there doesn't seem to be a good reason not to do this, so I'm going to make an attempt at it. Here are some things we've learned that might be relevant:

Regarding actually implementing an expansion in our importer or converter, some things to keep track of:

  • Once we have expansion working, there should be an assertion somewhere to catch cases where operators we expect to be expanded are somehow still present in the IR.
  • In case we want to keep a bespoke conversion for some particular operator, there should be an allowlist/denylist setup to prevent those operators from being expanded, so they can be converted later (a rough sketch of such a setup follows at the end of this comment).
  • There should be some instrumentation on the expansion, so it's easy to see which operators are getting expanded, enabling decision-making about possible future special-casing.

If we do implement an expansion, then:

  • Should we get rid of any existing bespoke lowerings that are now redundant? I know there's at least one. Maybe something for follow-up work.
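
Purely to illustrate the allowlist/denylist and instrumentation points above, here is a hypothetical sketch (the names are invented and this is not torch-mlir code):

# Hypothetical sketch of an expansion gate with a denylist and a counter
# recording which operators actually get expanded.
from collections import Counter

EXPANSION_DENYLIST = {"SomeOpWithABespokeLowering"}  # illustrative only

expansion_counts = Counter()

def should_expand(op_type: str, schema) -> bool:
    """Decide whether to expand an operator into its ONNX function body."""
    if op_type in EXPANSION_DENYLIST:
        return False
    if not (schema.has_function or schema.has_context_dependent_function):
        return False
    expansion_counts[op_type] += 1  # instrumentation for future special-casing decisions
    return True
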
@andfau-amd
Contributor Author

andfau-amd commented May 27, 2024

> The ONNX API has a function called FunctionExpandHelper (in Python: function_expand_helper)
>
> ONNX uses [the Python version of] it for generating expanded versions of tests

Oops, that's not the same function! I'd missed this:

https://github.com/onnx/onnx/blob/b90e252da11dea9bdc191d6b9b8d01511ef3e3bd/onnx/backend/test/case/node/__init__.py#L92-L95

So the Python variant of this function exists only internally to that test code. This might not be good for the idea of doing the expansion in our Python converter (python/torch_mlir/tools/import_onnx/__main__.py), but we'll see.

@stellaraccident
Collaborator

I haven't looked at this super closely, but my initial thought was that the importer could just be taught to import such functions as private func.func and then emit a func.call to them when used. I think that would match the semantics and leave most of the work in MLIR.

@andfau-amd
Contributor Author

andfau-amd commented May 27, 2024

Oh, that would actually suit me pretty well. Because the importer is written in Python, and the public ONNX Python API currently exposes no "proper" way to do single-step in-place expansion, I'm attempting to do it in two steps: turn these ops into local functions first, then use the ONNX inlining API. But I could skip the inlining at the ONNX level and import them as MLIR local functions instead. That would actually give more readable MLIR (edit: and, come to think of it, also avoid redundant work in the importer), so it seems like the better idea.
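
For reference, here is a rough, untested sketch of that two-step ONNX-level approach (the "local.expanded" domain name is made up, and I'm assuming onnx.inliner.inline_local_functions from recent onnx releases handles the attribute binding; this is not what ended up being merged):

# Rough sketch: register each function-defined operator's body as a local
# function, retarget the node at it, then let the ONNX inliner expand it.
import onnx
import onnx.defs
import onnx.inliner

def expand_function_ops(model: onnx.ModelProto) -> onnx.ModelProto:
    opset = {i.domain: i.version for i in model.opset_import}
    local_domain = "local.expanded"  # arbitrary private domain name
    added = set()
    for node in model.graph.node:
        if node.domain != "":
            continue  # only handle the default ONNX domain here
        schema = onnx.defs.get_schema(node.op_type, opset.get("", 1))
        if not schema.has_function:
            continue
        if node.op_type not in added:
            fn = onnx.FunctionProto()
            fn.CopyFrom(schema.function_body)
            fn.name = node.op_type
            fn.domain = local_domain
            model.functions.append(fn)
            added.add(node.op_type)
        node.domain = local_domain  # the node now calls the local function
    if added:
        model.opset_import.add(domain=local_domain, version=1)
    return onnx.inliner.inline_local_functions(model)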

@stellaraccident
Collaborator

Yeah, this is how some other frontends do it: get it into MLIR and let the infra there take over.

Also keeps the importer mechanical.

@andfau-amd
Contributor Author

It was more complicated than I expected, and I might have missed some things, but I managed to get something working!

I can get it to spit out an MLIR function for the relevant ONNX operator and call it:

python -m torch_mlir.tools.import_onnx mlir_venv/lib/python3.10/site-packages/onnx/backend/test/data/node/test_mvn/model.onnx
module {
  func.func @test_mvn(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = call @"('MeanVarianceNormalization', '', 18, [tensor_type {\0A  elem_type: 1\0A  shape {\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 1\0A    }\0A  }\0A}\0A], [tensor_type {\0A  elem_type: 1\0A  shape {\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 1\0A    }\0A  }\0A}\0A], [])"(%arg0) : (!torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32>
    return %0 : !torch.vtensor<[3,3,3,1],f32>
  }
  func.func @"('MeanVarianceNormalization', '', 18, [tensor_type {\0A  elem_type: 1\0A  shape {\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 1\0A    }\0A  }\0A}\0A], [tensor_type {\0A  elem_type: 1\0A  shape {\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 1\0A    }\0A  }\0A}\0A], [])"(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> attributes {torch.onnx_meta.ir_version = 0 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2.000000e+00> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<9.99999971E-10> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %2 = torch.operator "onnx.Constant"() {torch.onnx.value_ints = [0 : si64, 2 : si64, 3 : si64]} : () -> !torch.vtensor<[3],si64>
    %3 = torch.operator "onnx.ReduceMean"(%arg0, %2) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,3,1,1],f32>
    %4 = torch.operator "onnx.Pow"(%3, %0) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %5 = torch.operator "onnx.Pow"(%arg0, %0) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[3,3,3,1],f32>
    %6 = torch.operator "onnx.ReduceMean"(%5, %2) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,3,1,1],f32>
    %7 = torch.operator "onnx.Sub"(%6, %4) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %8 = torch.operator "onnx.Sqrt"(%7) : (!torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %9 = torch.operator "onnx.Sub"(%arg0, %3) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[3,3,3,1],f32>
    %10 = torch.operator "onnx.Add"(%8, %1) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %11 = torch.operator "onnx.Div"(%9, %10) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[3,3,3,1],f32>
    return %11 : !torch.vtensor<[3,3,3,1],f32>
  }
}

I apologise for the very ugly function name. Hopefully a slightly more pleasant name mangling scheme is possible. :)
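
(Purely as an illustration of what a nicer scheme could look like, and not what the importer currently does: one option would be to keep the operator name and append a short digest of the full specialization key.)

# Hypothetical name-mangling sketch: readable, and unique per specialization.
import hashlib

def mangle_function_name(op_type: str, specialization_key: str) -> str:
    digest = hashlib.sha256(specialization_key.encode()).hexdigest()[:8]
    return f"onnx_function.{op_type}.{digest}"

# e.g. "onnx_function.MeanVarianceNormalization.<8-hex-digit digest>"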

Then if I change the importer so that the main function is marked as public, and the operator's function is marked as private, MLIR can inline it and throw it away:

python -m torch_mlir.tools.import_onnx mlir_venv/lib/python3.10/site-packages/onnx/backend/test/data/node/test_mvn/model.onnx | build/bin/torch-mlir-opt --split-input-file --inline
module {
  func.func public @test_mvn(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2.000000e+00> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<9.99999971E-10> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %2 = torch.operator "onnx.Constant"() {torch.onnx.value_ints = [0 : si64, 2 : si64, 3 : si64]} : () -> !torch.vtensor<[3],si64>
    %3 = torch.operator "onnx.ReduceMean"(%arg0, %2) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,3,1,1],f32>
    %4 = torch.operator "onnx.Pow"(%3, %0) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %5 = torch.operator "onnx.Pow"(%arg0, %0) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[3,3,3,1],f32>
    %6 = torch.operator "onnx.ReduceMean"(%5, %2) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,3,1,1],f32>
    %7 = torch.operator "onnx.Sub"(%6, %4) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %8 = torch.operator "onnx.Sqrt"(%7) : (!torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %9 = torch.operator "onnx.Sub"(%arg0, %3) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[3,3,3,1],f32>
    %10 = torch.operator "onnx.Add"(%8, %1) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %11 = torch.operator "onnx.Div"(%9, %10) : (!torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[3,3,3,1],f32>
    return %11 : !torch.vtensor<[3,3,3,1],f32>
  }
}

And this looks fairly similar at first glance to the known-good expansion, though I'm not yet sure if it's perfectly equivalent.
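
(Regarding marking the main function public and the operator functions private so the inliner can discard them: a minimal sketch of how visibility can be set through the MLIR Python bindings, in generic op-attribute form rather than the exact importer code, assuming the upstream mlir package; torch-mlir ships its own copy of these bindings.)

# Minimal sketch: set symbol visibility on a func.func so that the MLIR
# inliner can later discard private functions after inlining them.
from mlir import ir

def set_visibility(func_op: ir.Operation, visibility: str) -> None:
    # "private" or "nested"; a func.func with no sym_visibility attribute
    # is public by default.
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        visibility, func_op.context)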

You can see what I have so far at main...andfau-amd:torch-mlir:onnx-to-torch-function-expansion. The interesting stuff is in get_operator_function, the rest is mostly plumbing.

(Perhaps I should make a draft PR? I think it would be useful to get some early feedback.)

@andfau-amd
Contributor Author

andfau-amd commented May 29, 2024

Resolved my FIXMEs; now the MLIR for MeanVarianceNormalization perfectly matches the known-good expansion once it's inlined and cleaned up:

$ python -m torch_mlir.tools.import_onnx mlir_venv/lib/python3.10/site-packages/onnx/backend/test/data/node/test_mvn/model.onnx | build/bin/torch-mlir-opt --split-input-file --inline --convert-torch-onnx-to-torch --cse --canonicalize > mvn-ours.txt
$ python -m torch_mlir.tools.import_onnx mlir_venv/lib/python3.10/site-packages/onnx/backend/test/data/node/test_mvn_expanded/model.onnx | build/bin/torch-mlir-opt --split-input-file --inline --convert-torch-onnx-to-torch --cse --canonicalize > mvn-theirs.txt
$ sed -e 's/_expanded//g' mvn-theirs.txt > mvn-theirs2.txt
$ diff mvn-ours.txt mvn-theirs2.txt
$

I'll try to add some tests and then I'll open a pull request.

@renxida
Collaborator

renxida commented May 29, 2024

This is awesome

@andfau-amd
Contributor Author

Thank you!

@andfau-amd
Contributor Author

I've now gotten to the point where cmake --build build --target check-torch-mlir no longer fails with my patch. Lots of edge cases that were causing the importer to crash in pre-existing regression tests had to be fixed. I think I know too much about ONNX now. :) I have actually had to denylist two operators to prevent them from being expanded, for fairly good reasons. :(

With that out of the way I can actually start adding new tests.

andfau-amd added a commit to andfau-amd/torch-mlir that referenced this issue May 31, 2024
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit changes onnx_importer.py to systematically perform this
expansion for all ONNX operators that are not explicitly denylisted.
When importing a node, the schema for the node's operation is retrieved.
If the schema provides a function for the operator, a specialized
version for the node's types and attributes will be created and imported
as an MLIR function with private visibility. An MLIR function call will
then be emitted instead of a normal operator node. Caching is used to
avoid generating redundant functions within the same module.

Note that previously all MLIR functions generated by the importer had no
visibility specified. This commit changes this: the main function for a
model is now public. This is so that the MLIR inliner pass will
automatically discard the (private) operator functions after inlining.

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently
  find_type_proto_for_name() now gets called on graph inputs, not just
  intermediate values and graph outputs, so it has to be updated.
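
To make the context-dependent case above concrete, here is a hedged sketch of the schema/function lookup (based on my reading of the onnx Python API, not the exact importer code):

# Sketch: retrieve a FunctionProto for a node, handling both the static and
# the context-dependent (input-type-dependent) cases.
from typing import List, Optional

import onnx
import onnx.defs

def get_function_proto(node: onnx.NodeProto,
                       opset_version: int,
                       input_types: List[onnx.TypeProto]) -> Optional[onnx.FunctionProto]:
    schema = onnx.defs.get_schema(node.op_type, opset_version, node.domain)
    if schema.has_function:
        return schema.function_body
    if schema.has_context_dependent_function:
        # The binding works on serialized protos and returns serialized bytes.
        raw = schema.get_context_dependent_function(
            node.SerializeToString(),
            [t.SerializeToString() for t in input_types])
        fn = onnx.FunctionProto()
        fn.ParseFromString(raw)
        return fn
    return None
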
@andfau-amd
Contributor Author

Created a PR to track things beyond this point and allow commenting on the code etc. It's not quite ready for merging yet though. #3409

andfau-amd added a commit to andfau-amd/torch-mlir that referenced this issue Jun 4, 2024
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit changes onnx_importer.py to systematically perform this
expansion for all ONNX operators that are not explicitly denylisted.
When importing a node, the schema for the node's operation is retrieved.
If the schema provides a function for the operator, a specialized
version for the node's types and attributes will be created and imported
as an MLIR function with private visibility. An MLIR function call will
then be emitted instead of a normal operator node. Caching is used to
avoid generating redundant functions within the same module.

Note that previously all MLIR functions generated by the importer had no
visibility specified. This commit changes this: the main function for a
model is now public. This is so that the MLIR inliner pass will
automatically discard the (private) operator functions after inlining.

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
andfau-amd added a commit to andfau-amd/torch-mlir that referenced this issue Jun 4, 2024
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit changes onnx_importer.py to systematically perform this
expansion for all ONNX operators that are not explicitly denylisted.
When importing a node, the schema for the node's operation is retrieved.
If the schema provides a function for the operator, a specialized
version for the node's types and attributes will be created and imported
as an MLIR function with private visibility. An MLIR function call will
then be emitted instead of a normal operator node. Caching is used to
avoid generating redundant functions within the same module.

Note that previously all MLIR functions generated by the importer had no
visibility specified. This commit changes this: the main function for a
model is now public. This is so that the MLIR inliner pass will
automatically discard the (private) operator functions after inlining.

Some consequences for things downstream of the importer:

- Inlining should now be done before doing any lowering, for example
  `torch-mlir-opt --inline --convert-torch-onnx-to-torch`.
- Some lowerings in TorchOnnxToTorch are now redundant and perhaps can
  be removed.

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this. This also turned
  out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
@andfau-amd
Contributor Author

Marked the PR as ready for review. It still needs some more testing, but I hope that can be done in parallel.

andfau-amd added a commit to andfau-amd/torch-mlir that referenced this issue Jun 5, 2024
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit changes onnx_importer.py to systematically perform this
expansion for all ONNX operators that are not explicitly denylisted.
When importing a node, the schema for the node's operation is retrieved.
If the schema provides a function for the operator, a specialized
version for the node's types and attributes will be created and imported
as an MLIR function with private visibility. An MLIR function call will
then be omitted, instead of a normal operator node. Caching is used to
avoid generating redundant functions within the same module.

Note that previously all MLIR functions generated by the importer had no
visibility specified. This commit changes this: the main function for a
model is now public. This is so that the MLIR inliner pass will
automatically discard the (private) operator functions after inlining.

Some consequences for things downstream of the importer:

- Inlining should now be done before doing any lowering, for example
  `torch-mlir-opt --inline --convert-torch-onnx-to-torch`.
- Some lowerings in TorchOnnxToTorch are now redundant and perhaps can
  be removed.

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this. This also turned
  out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
andfau-amd added a commit to andfau-amd/torch-mlir that referenced this issue Jun 14, 2024
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit adds this capability to onnx_importer.py. When importing a
node, the schema for the node's operation is retrieved. If the schema
provides a function for the operator, a specialized version for the
node's types and attributes will be created and imported as an MLIR
function with private visibility. An MLIR function call will then be
emitted, instead of a normal operator node. Caching is used to avoid
generating redundant functions within the same module.

In order to avoid a disruptive change to the importer output for a
large number of operations that already have TorchOnnxToTorch support,
an allowlist strategy is used by default. With this commit, only two
operations are allowlisted for expansion: MeanVarianceNormalization and
NegativeLogLikelihoodLoss. Hopefully this list can be gradually
expanded. It is possible to disable the allowlist in the configuration,
in which case all functions are expanded (useful for testing).

Tools downstream of the importer may now need to do inlining when
consuming the output of the importer, e.g.:

  cat imported.mlir | torch-mlir-opt --inline --convert-torch-onnx-to-torch

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this. This also turned
  out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
andfau-amd added a commit to andfau-amd/torch-mlir that referenced this issue Jun 14, 2024
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit adds this capability to onnx_importer.py. When importing a
node, the schema for the node's operator is retrieved. If the schema
provides a function for the operator, a specialized version for the
node's types and attributes will be created and imported as an MLIR
function with private visibility. An MLIR function call will then be
emitted, instead of a normal operator node. Caching is used to avoid
generating redundant functions within the same module.

In order to avoid a disruptive change to the importer output for a
large number of operators that already have TorchOnnxToTorch support,
an allowlist strategy is used by default. With this commit, only two
operators are allowlisted for expansion: MeanVarianceNormalization and
NegativeLogLikelihoodLoss. Hopefully this list can be gradually
expanded. It is possible to disable the allowlist in the configuration,
in which case all functions are expanded (useful for testing).

Tools downstream of the importer may now need to do inlining when
consuming the output of the importer, e.g.:

  cat imported.mlir | torch-mlir-opt --inline --convert-torch-onnx-to-torch

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this. This also turned
  out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
andfau-amd added a commit to andfau-amd/torch-mlir that referenced this issue Jun 14, 2024
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit adds this capability to onnx_importer.py. When importing a
node, the schema for the node's operator is retrieved. If the schema
provides a function for the operator, a specialized version for the
node's types and attributes will be created and imported as an MLIR
function with private visibility. An MLIR function call will then be
emitted, instead of a normal operator node. Caching is used to avoid
generating redundant functions within the same module.

In order to avoid a disruptive change to the importer output for a
large number of operators that already have TorchOnnxToTorch support,
an allowlist strategy is used by default. With this commit, only one
operator is allowlisted for expansion, MeanVarianceNormalization.
However, many other operators can be correctly expanded by the current
code, so hopefully the allowlist can be gradually extended. It is
possible to disable the allowlist in the configuration, in which case
all functions are expanded (useful for testing).

Tools downstream of the importer may now need to do inlining when
consuming the output of the importer, e.g.:

  cat imported.mlir | torch-mlir-opt --inline --convert-torch-onnx-to-torch

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this. This also turned
  out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
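
For downstream users driving the importer programmatically, the allowlist described above is presumably exposed through the importer's Config; a hedged sketch follows (the field name function_expansion_allowlists_by_domain is my assumption here; check onnx_importer.py in the PR for the authoritative spelling):

# Hedged sketch: configuring which ONNX domains/operators get expanded.
# The exact Config field name is an assumption; see onnx_importer.py.
from torch_mlir.extras import onnx_importer

config = onnx_importer.Config()

# Expand only MeanVarianceNormalization from the default ("") domain...
config.function_expansion_allowlists_by_domain = {
    "": {"MeanVarianceNormalization"},
}

# ...or set it to None to disable the allowlist and expand every
# function-defined operator (useful for testing).
config.function_expansion_allowlists_by_domain = None
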
@andfau-amd
Contributor Author

Follow-up issue for expanding the allowlist: #3464
