Skip to content

Commit

Permalink
Add pass for embedding simple metadata with iree.reflection attribute. (
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored Jun 26, 2024
1 parent 8ac7aa0 commit b489d8b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
59 changes: 59 additions & 0 deletions shark_turbine/transforms/general/add_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2024 Advanced Micro Devices, inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
This pass will add a specified dictionary as an iree.reflection attribute to a module's public function(s).
"""

from typing import Callable, Dict, List, Optional, Tuple, Union

import re

from shark_turbine.support.ir_imports import *

from ..rewriter import *
from iree.compiler.ir import Context, DictAttr


def value_to_attr(value):
val_type = type(value).__name__
match val_type:
case "str":
return StringAttr.get(value)
case _:
return StringAttr.get(str(value))


class AddMetadataPass(Pass):
def __init__(
self,
mlir_module: Module,
inp_metadata: dict,
func_name: str,
):
super().__init__(mlir_module.operation)
self.mlir_module = mlir_module
self.inp_metadata = inp_metadata
self.func_name = func_name
self.context = self.mlir_module.context

def run(self):
def parse_metadata_dict(metadata_dict: dict) -> DictAttr:
with self.context:
for key, value in metadata_dict.items():
metadata_dict[key] = value_to_attr(value)
metadata_dict = DictAttr.get(metadata_dict)
return metadata_dict

metadata_dict_attr = parse_metadata_dict(self.inp_metadata)
for func_op in self.funcs:
if func_op.op.attributes[1].attr.value == self.func_name:
func_op.op.attributes["iree.reflection"] = metadata_dict_attr
return self.mlir_module


if __name__ == "__main__":
pass_main(AddMetadataPass)
47 changes: 47 additions & 0 deletions tests/transforms/general/add_metadata_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 Advanced Micro Devices, Inc
# Portions Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import logging
import unittest

from iree.compiler.ir import Context, Operation, Module

from shark_turbine.transforms.general import add_metadata

SIMPLE_FUNC_ASM = r"""
func.func @list_func(%arg0 : !iree_input.list<!iree_input.variant>) -> !iree_input.list<!iree_input.variant> {
return %arg0 : !iree_input.list<!iree_input.variant>
}
"""


class MetadataTest(unittest.TestCase):
def testBasic(self):
metadata_dict = {
"test_data_str": "test_data_str_value",
"test_data_int": 42,
"test_data_dict": {"test_data_dict_key": "test_data_dict_value"},
"test_data_list": ["test_data_list_value"],
"test_data_float": 3.14159,
"test_data_bool": True,
"test_data_tuple": (1,),
}
with Context() as context:
module = Module.parse(SIMPLE_FUNC_ASM)
module_op = add_metadata.AddMetadataPass(
module.operation,
metadata_dict,
"list_func",
).run()
module_asm = str(module_op)
print(module_asm)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()

0 comments on commit b489d8b

Please sign in to comment.