forked from llvm/torch-mlir
Create MLIR functions for ONNX operators that are functions
Resolves llvm#3384.

Many ONNX operators are defined by functions and can therefore be expanded into simpler ONNX operations during import, avoiding the need for downstream tools to support these operators directly. This commit adds that capability to onnx_importer.py. When importing a node, the schema for the node's operator is retrieved. If the schema provides a function for the operator, a version specialized for the node's types and attributes is created and imported as an MLIR function with private visibility, and an MLIR function call is emitted instead of a normal operator node. Caching is used to avoid generating redundant functions within the same module.

To avoid a disruptive change to the importer output for the large number of operators that already have TorchOnnxToTorch support, an allowlist strategy is used by default. With this commit, only two operators are allowlisted for expansion: MeanVarianceNormalization and NegativeLogLikelihoodLoss. Hopefully this list can be gradually expanded. The allowlist can be disabled in the configuration, in which case all functions are expanded (useful for testing).

Tools downstream of the importer may now need to inline the importer's output when consuming it, e.g.:

    cat imported.mlir | torch-mlir-opt --inline --convert-onnx-to-torch

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires knowing the opset version. NodeImporter retrieves this from the opset imports on the ModelProto retained by the GraphInfo. Previously, the model_proto field on GraphInfo was None when importing a subgraph in import_regions, but this conflicts with the new need for opset version info. Since the apparent purpose of setting it to None was to control how GraphInfo generates its input map, a new flag (is_subgraph) is added to GraphInfo to control that behavior, so the actual ModelProto can now be provided without breaking anything. This also turned out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent: the function definition depends on the types of the inputs. Node importing therefore now needs to look up the types of a node's inputs, not just its outputs as was previously the case. Consequently, the operand to find_type_proto_for_name() may now be a graph input or initializer in some cases, so it has been updated to handle those.
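The overall flow described above (look up the opset version from the model's opset imports, consult the allowlist, and cache one specialization per distinct operator/types/attributes combination) can be sketched in plain Python. This is a minimal illustration only; the class, method names, and name-mangling scheme here are hypothetical and do not match onnx_importer.py's actual API.

```python
# Illustrative sketch of the expansion strategy: allowlist check plus a
# cache of specialized functions. All names are hypothetical.

DEFAULT_ALLOWLIST = {"MeanVarianceNormalization", "NegativeLogLikelihoodLoss"}


def opset_version(opset_imports, domain=""):
    # opset_imports: mapping of domain -> version, as would be taken from
    # the ModelProto's opset_import list (the default ONNX domain is "").
    return opset_imports.get(domain)


class FunctionExpander:
    def __init__(self, allowlist=DEFAULT_ALLOWLIST, disable_allowlist=False):
        self.allowlist = allowlist
        self.disable_allowlist = disable_allowlist
        self.cache = {}  # specialization key -> mangled MLIR function name

    def should_expand(self, op_type, has_function):
        # Only operators whose schema provides a function can be expanded;
        # by default, only allowlisted operators actually are.
        if not has_function:
            return False
        return self.disable_allowlist or op_type in self.allowlist

    def function_name(self, op_type, version, input_types, attrs):
        # Specialize per types and attributes; reuse cached expansions so
        # identical nodes share one private MLIR function.
        key = (op_type, version, tuple(input_types),
               tuple(sorted(attrs.items())))
        if key not in self.cache:
            self.cache[key] = f"{op_type}_{version}_{len(self.cache)}"
        return self.cache[key]
```

With this shape, importing two identical GreaterOrEqual nodes would produce one private function definition and two calls to it.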
1 parent a02e14e
commit 66920b8
Showing 9 changed files with 547 additions and 29 deletions.
(Large diffs for some changed files are not rendered.)
test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit: 18 additions, 0 deletions
@@ -0,0 +1,18 @@
# Test that expansion of ONNX operators that are functions works for a simple
# example. The exact name-mangling scheme used is not matched against; all that
# matters is that it contains the name of the operator (GreaterOrEqual here).
# Attributes are also not checked here. What we are interested in is the types
# and operations.
#
# The model comes from an upstream ONNX test: backend/test/data/node/test_greater_equal/model.onnx

# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s

# CHECK-LABEL: func.func @test_greater_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: %0 = call @"{{.*}}GreaterOrEqual{{.*}}"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>

# CHECK-LABEL: func.func private @"{{.*}}GreaterOrEqual{{.*}}"(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: %0 = torch.operator "onnx.Greater"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: %1 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: %2 = torch.operator "onnx.Or"(%0, %1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: return %2 : !torch.vtensor<[3,4,5],i1>
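The expansion checked above mirrors how the function defines GreaterOrEqual in terms of Greater, Equal, and Or. Elementwise, the decomposition amounts to the following sketch (plain Python for illustration, not the importer's code):

```python
# Elementwise sketch of the GreaterOrEqual expansion the CHECK lines
# above verify: GreaterOrEqual(a, b) = Or(Greater(a, b), Equal(a, b)).

def greater_or_equal(a, b):
    greater = [x > y for x, y in zip(a, b)]   # onnx.Greater
    equal = [x == y for x, y in zip(a, b)]    # onnx.Equal
    return [g or e for g, e in zip(greater, equal)]  # onnx.Or
```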
Binary file added: test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit.onnx (+171 bytes, not shown)
test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit: 22 additions, 0 deletions
@@ -0,0 +1,22 @@
# Test the expansion of ONNX operators that are functions, specifically the
# propagation of attribute values from the call site to nodes within the
# expanded function.
#
# In this case, the model has a ReduceSumSquare node with the attribute
# 'keepdims' set to 0, and the definition of this version of ReduceSumSquare
# contains a ReduceSum node that references the value of 'keepdims', so we
# expect to see this value propagated to the ReduceSum node in the expansion.
#
# This also tests that the absence of 'axes' (as an optional attribute with no
# default value) is propagated in the same way.
#
# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_do_not_keepdims_example/model.onnx

# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
#
# CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}"
#
# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}"
# CHECK: %0 = torch.operator "onnx.Mul"
# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 0 : si64}
Binary file added: test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit.onnx (+205 bytes, not shown)
test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit: 23 additions, 0 deletions
@@ -0,0 +1,23 @@
# Test the expansion of ONNX operators that are functions, specifically the
# propagation of attribute values from the call site to nodes within the
# expanded function.
#
# In this case, the model has a ReduceSumSquare node with no attributes, but the
# definition of this version of ReduceSumSquare contains a ReduceSum node that
# references the value of 'keepdims', and the definition says its default value
# is 1, so we expect to see this value propagated to the ReduceSum node in the
# expansion.
#
# This also tests that the absence of 'axes' (as an optional attribute with no
# default value) is propagated in the same way.
#
# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_empty_set/model.onnx

# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
#
# CHECK-LABEL: func.func @test_reduce_sum_square_empty_set
# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}"
#
# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}"
# CHECK: %0 = torch.operator "onnx.Mul"
# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 1 : si64}
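The two ReduceSumSquare tests above exercise the same decomposition, Mul(x, x) followed by ReduceSum, differing only in how the 'keepdims' attribute reaches the inner ReduceSum node: explicitly set to 0 in one model, and filled in from the definition's default of 1 in the other. A rough one-axis sketch of that semantics (plain Python for illustration, not the importer's code):

```python
# Sketch of the ReduceSumSquare expansion the tests above check:
# Mul(x, x) then ReduceSum, with the 'keepdims' attribute propagated
# to ReduceSum (defaulting to 1 when absent at the call site).

def reduce_sum_square(values, keepdims=1):
    squared = [v * v for v in values]  # the expanded onnx.Mul
    total = sum(squared)               # the expanded onnx.ReduceSum
    # keepdims=1 keeps the reduced axis as size 1; keepdims=0 drops it.
    return [total] if keepdims else total
```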
Binary file added: test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit.onnx (+195 bytes, not shown)