Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Create MLIR functions for ONNX operators that are functions
Resolves llvm#3384. Many ONNX operators are defined by functions and therefore could be expanded into simpler ONNX operations during importing, avoiding the need for tools downstream to support these operators directly. This commit changes onnx_importer.py to systematically perform this expansion for all ONNX operators that are not explicitly denylisted. When importing a node, the schema for the node's operation is retrieved. If the schema provides a function for the operator, a specialized version for the node's types and attributes will be created and imported as an MLIR function with private visibility. An MLIR function call will then be omitted, instead of a normal operator node. Caching is used to avoid generating redundant functions within the same module. Note that previously all MLIR functions generated by the importer had no visibility specified. This commit changes this: the main function for a model is now public. This is so that the MLIR inliner pass will automatically discard the (private) operator functions after inlining. Explanations for subtle code changes: - Looking up the correct schema and function for an operator requires knowing the opset version. NodeImporter retrieves this from the opset imports on the ModelProto retained by the GraphInfo. Previously, the model_proto field on GraphInfo was None when importing a subgraph in import_regions, but this conflicts with the new need for opset version info. Since the apparent purpose of setting it to None was to control how GraphInfo generates its input map, a new flag is added to GraphInfo (is_subgraph) to control this behavior, so that the actual ModelProto can now be provided without breaking this. - Some operators' functions are context-dependent, which means the function definition depends on the types of the inputs. Therefore node importing now needs to look up the types of a node's inputs, not just its outputs as was the case previously. Consequently find_type_proto_for_name() now gets called on graph inputs, not just intermediate values and graph outputs, so it has to be updated.
- Loading branch information