Create MLIR functions for ONNX operators that are functions
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit adds this capability to onnx_importer.py. When importing a
node, the schema for the node's operator is retrieved. If the schema
provides a function for the operator, a version of the function
specialized for the node's types and attributes is created and imported
as an MLIR function with private visibility, and an MLIR function call
is emitted instead of a normal operator node. Caching is used to avoid
generating redundant functions within the same module.
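The specialize-and-cache scheme described above can be sketched as follows. This is a hypothetical illustration, not the actual onnx_importer.py code: all class, method, and symbol names here are invented for the example.

```python
# Hypothetical sketch of per-module caching of expanded ONNX functions.
# A function specialized for an operator's opset version, input types, and
# attributes is generated at most once; later nodes reuse the cached symbol.

class FunctionExpansionCache:
    """Caches one specialized MLIR function per unique specialization key."""

    def __init__(self):
        self._symbols = {}       # specialization key -> MLIR symbol name
        self.num_generated = 0   # how many functions were actually emitted

    @staticmethod
    def make_key(op_name, opset_version, input_types, attributes):
        # The key must capture everything the specialized body can depend on:
        # the operator, its opset version, input types, and attribute values.
        return (op_name, opset_version, tuple(input_types),
                tuple(sorted(attributes.items())))

    def get_or_create(self, key, generate):
        """Return the cached symbol for `key`, generating it on first use."""
        if key not in self._symbols:
            self._symbols[key] = generate()
            self.num_generated += 1
        return self._symbols[key]
```

Two nodes with the same operator, opset version, input types, and attributes then share a single private MLIR function, so the module contains no redundant definitions.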

In order to avoid a disruptive change to the importer output for a
large number of operators that already have TorchOnnxToTorch support,
an allowlist strategy is used by default. With this commit, only two
operators are allowlisted for expansion: MeanVarianceNormalization and
NegativeLogLikelihoodLoss. Hopefully this list can be gradually
expanded. It is possible to disable the allowlist in the configuration,
in which case all functions are expanded (useful for testing).
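The allowlist semantics can be modeled as below. This is a simplified sketch, not the importer's actual code; only the two allowlisted operator names and the fact that setting the allowlist mapping to None disables it are taken from this commit.

```python
# Simplified model of the expansion allowlist. The empty-string domain is the
# default ONNX operator set. Setting the mapping to None disables allowlisting
# entirely, so every operator defined by a function gets expanded.

DEFAULT_ALLOWLISTS = {
    "": {"MeanVarianceNormalization", "NegativeLogLikelihoodLoss"},
}

def should_expand(op_name, domain, allowlists_by_domain=DEFAULT_ALLOWLISTS):
    if allowlists_by_domain is None:
        return True  # allowlist disabled: expand all function operators
    return op_name in allowlists_by_domain.get(domain, set())
```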

Tools downstream of the importer may now need to inline its output when
consuming it, e.g.:

  cat imported.mlir | torch-mlir-opt --inline --convert-onnx-to-torch

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag
  (is_subgraph) is added to GraphInfo to control that behavior, so the
  actual ModelProto can now be provided without changing the input map.
  This also turned out to be useful for reaching the Config via
  ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs, as was previously the case. Consequently, the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so that function has been updated to handle them.
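The opset-version lookup in the first point above can be sketched like this. Plain dicts stand in for the protobuf `opset_import` entries on the ModelProto, and the helper name is illustrative, not the importer's actual API.

```python
# Each entry in a ModelProto's opset_import maps an operator-set domain to the
# opset version the model targets. The default ONNX domain is the empty
# string; "ai.onnx" is treated as an alias for it in this sketch.

def opset_version_for_domain(opset_imports, domain):
    wanted = {"", "ai.onnx"} if domain in ("", "ai.onnx") else {domain}
    for entry in opset_imports:
        if entry["domain"] in wanted:
            return entry["version"]
    return None  # domain not imported by this model
```

The schema (and hence the function body, if any) for an operator must then be resolved against the version this lookup returns.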
andfau-amd committed Jun 14, 2024
1 parent a02e14e commit 66920b8
Showing 9 changed files with 547 additions and 29 deletions.
@@ -97,7 +97,10 @@ def _module_lowering(
     # Lower from ONNX to Torch
     run_pipeline_with_repro_report(
         torch_mod,
-        f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
+        # The importer may produce additional MLIR functions corresponding to
+        # ONNX operators that are functions. In some cases they need to be
+        # inlined to avoid the backend choking on them.
+        f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
         "Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
     )

496 changes: 469 additions & 27 deletions python/torch_mlir/extras/onnx_importer.py

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion python/torch_mlir/tools/import_onnx/__main__.py
@@ -31,10 +31,14 @@


 def main(args: argparse.Namespace):
+    config = onnx_importer.Config()
+    if args.disable_function_expansion_allowlist:
+        config.function_expansion_allowlists_by_domain = None
+
     model_proto = load_onnx_model(args)
     context = Context()
     torch_d.register_dialect(context)
-    model_info = onnx_importer.ModelInfo(model_proto)
+    model_info = onnx_importer.ModelInfo(model_proto, config=config)
     m = model_info.create_module(context=context).operation
     imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
     imp.import_all()
@@ -195,6 +199,12 @@ def parse_arguments(argv=None) -> argparse.Namespace:
         " to before importing to MLIR. This can sometime assist with shape inference.",
         type=int,
     )
+    parser.add_argument(
+        "--disable-function-expansion-allowlist",
+        action="store_true",
+        help="Disable the allowlist for ONNX function expansion,"
+        " allowing non-allowlisted functions to be expanded.",
+    )
     args = parser.parse_args(argv)
     return args
18 changes: 18 additions & 0 deletions test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit
@@ -0,0 +1,18 @@
# Test that expansion of ONNX operators that are functions works for a simple
# example. The exact name mangling scheme used is not matched against, all that
# matters is that it has the name of the operator (GreaterOrEqual here) in it.
# Attributes are also not checked here. What we are interested in is the types
# and operations.
#
# The model comes from an upstream ONNX test: backend/test/data/node/test_greater_equal/model.onnx

# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s

# CHECK-LABEL: func.func @test_greater_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: %0 = call @"{{.*}}GreaterOrEqual{{.*}}"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>

# CHECK-LABEL: func.func private @"{{.*}}GreaterOrEqual{{.*}}"(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: %0 = torch.operator "onnx.Greater"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: %1 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: %2 = torch.operator "onnx.Or"(%0, %1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1>
# CHECK: return %2 : !torch.vtensor<[3,4,5],i1>
Binary file not shown.
@@ -0,0 +1,22 @@
# Test the expansion of ONNX operators that are functions, specifically the
# propagation of attribute values from the call-site to nodes within the
# expanded function.
#
# In this case, the model has a ReduceSumSquare node with the attribute
# 'keepdims' set to 0, and the definition of this version of ReduceSumSquare
# contains a ReduceSum node that references the value of 'keepdims', so we
# expect to see this value propagated to the ReduceSum node in the expansion.
#
# This also tests that the absence of 'axes' (as an optional attribute with no
# default value) is propagated in the same way.
#
# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_do_not_keepdims_example/model.onnx

# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
#
# CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}"
#
# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}"
# CHECK: %0 = torch.operator "onnx.Mul"
# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 0 : si64}
Binary file not shown.
@@ -0,0 +1,23 @@
# Test the expansion of ONNX operators that are functions, specifically the
# propagation of attribute values from the call-site to nodes within the
# expanded function.
#
# In this case, the model has a ReduceSumSquare node with no attributes, but the
# definition of this version of ReduceSumSquare contains a ReduceSum node that
# references the value of 'keepdims', and the definition says its default value
# is 1, so we expect to see this value propagated to the ReduceSum node in the
# expansion.
#
# This also tests that the absence of 'axes' (as an optional attribute with no
# default value) is propagated in the same way.
#
# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_empty_set/model.onnx

# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
#
# CHECK-LABEL: func.func @test_reduce_sum_square_empty_set
# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}"
#
# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}"
# CHECK: %0 = torch.operator "onnx.Mul"
# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 1 : si64}
Binary file not shown.
