diff --git a/iree/turbine/ops/iree.py b/iree/turbine/ops/iree.py index b4d79aee..9c4c28b3 100644 --- a/iree/turbine/ops/iree.py +++ b/iree/turbine/ops/iree.py @@ -6,13 +6,19 @@ """Custom ops for built-in IREE functionality.""" from typing import cast +import torch from ..support.ir_imports import ( Attribute, + ArrayAttr, + Block, + BlockArgument, + DictAttr, RankedTensorType, StringAttr, Value, flow_d, + func_d, tensor_d, ) @@ -66,7 +72,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): # homogenous devices, there can easily be a "1", "2", etc. However, note that # there is nothing at this level that requires devices to be homogenous or # named in such a way. Internal to the module, this will require that a symbol -# with the name "__device.{moniker}" is provided in some fashion (spec file, +# with the name "__device_{moniker}" is provided in some fashion (spec file, # command line flags, etc). # # Within a graph, transfering tensors to a device causes partitioning and @@ -86,8 +92,8 @@ def select(self, ksel: KernelSelection): ta.specialize_all_dims() ksel.return_tensor(ta.t).specialize_all_dims() - def eager_execute(self, device_moniker, tensor): - return tensor + def eager_execute(self, device_moniker, tensor: torch.Tensor): + return tensor.clone() def generate(self, ksel: KernelSelection, kb: KernelBuilder): moniker = cast(AttrArg, ksel.arg_descs[0]).v @@ -99,11 +105,93 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): kb.yield_results(result) +@CustomOp.register(library=IREE_LIBRARY) +class transfer_to_logical_device_(CustomOp): + """In-place variant of transfer_to_logical_device. + Rather than materializing the as an MLIR operation the corresponding block argument""" + + signature = "transfer_to_logical_device_(str moniker, Tensor(a!) tensor) -> ()" + + def select(self, ksel: KernelSelection): + ksel.attr_str(0) + ta = ksel.arg_tensor(1, inplace_tied=True) + ta.specialize_all_dims() + + def eager_execute(self, device_moniker, tensor): + pass + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + moniker = cast(AttrArg, ksel.arg_descs[0]).v + t = kb.arg_bindings[1] + block_arg_value = t + + # Find the corresponding block argument. + # This is brittle. + # We assume that each op in the use-def chain has 1 operand. + while not isinstance(block_arg_value.owner, Block): + assert len(block_arg_value.owner.operands) == 1 + block_arg_value = block_arg_value.owner.operands[0] + block = block_arg_value.owner + parent_op = block.region.owner + + # TODO: use FunctionOpInterface + assert isinstance(parent_op, func_d.FuncOp) + assert parent_op.body.blocks[0] == block + + for arg in block.arguments: + if arg == block_arg_value: + block_arg = arg + break + _set_func_op_argument_attribute( + parent_op, + arg_index=block_arg.arg_number, + attr_name="iree.abi.affinity", + attr=Attribute.parse(f'#hal.device.promise<@"__device_{moniker}">'), + ) + + kb.yield_results(t) + + ################################################################################ # Emission utilities ################################################################################ +def _set_dict_attr_value(dict_attr: DictAttr, key: str, value: Attribute) -> DictAttr: + d = {named_attr.name: named_attr.attr for named_attr in dict_attr} + d[key] = value + return DictAttr.get(d, dict_attr.context) + + +def _set_array_attr_value( + array_attr: ArrayAttr, index: int, value: Attribute +) -> ArrayAttr: + l = [v for v in array_attr] + l[index] = value + return ArrayAttr.get(l, context=array_attr.context) + + +def _set_dict_array_attr_value( + array_attr: ArrayAttr, index: int, key: str, value: Attribute +) -> ArrayAttr: + dictAttr = _set_dict_attr_value(array_attr[index], key, value) + return _set_array_attr_value(array_attr, index, dictAttr) + + +def _set_func_op_argument_attribute( + func_op: func_d.FuncOp, arg_index: int, attr_name: str, attr: Attribute +): + if "arg_attrs" not in func_op.attributes: + arg_attrs = ArrayAttr.get( + [DictAttr.get(context=func_op.context)] * len(func_op.arguments), + context=func_op.context, + ) + else: + arg_attrs = func_op.arg_attrs + arg_attrs = _set_dict_array_attr_value(arg_attrs, arg_index, attr_name, attr) + func_op.arg_attrs = arg_attrs + + def _append_dynamic_dims(kb: KernelBuilder, dynamic_dims: list[Value], tensor: Value): rtt = RankedTensorType(tensor.type) for i in range(rtt.rank): diff --git a/iree/turbine/support/ir_imports.py b/iree/turbine/support/ir_imports.py index 09aa4042..5a2d932f 100644 --- a/iree/turbine/support/ir_imports.py +++ b/iree/turbine/support/ir_imports.py @@ -10,11 +10,13 @@ from iree.compiler.ir import ( AsmState, Attribute, + ArrayAttr, Block, BlockArgument, Context, DenseElementsAttr, DenseResourceElementsAttr, + DictAttr, FlatSymbolRefAttr, FloatAttr, FunctionType, diff --git a/tests/ops/iree_test.py b/tests/ops/iree_test.py index facbf545..16c54372 100644 --- a/tests/ops/iree_test.py +++ b/tests/ops/iree_test.py @@ -51,6 +51,25 @@ def forward(self, x): asm, "flow.tensor.transfer %.+ to #hal.device.promise<@__device.1>" ) + def testEagerInPlace(self): + t1 = torch.randn(3, 4) + t2 = ops.iree.transfer_to_logical_device_("1", t1) + self.assertIs(None, t2) + + def testAotInPlace(self): + class MyModule(nn.Module): + def forward(self, x): + ops.iree.transfer_to_logical_device_("1", x) + x += 1 + return x + + cm = aot.export(MyModule(), args=(torch.empty(9, 8),)) + asm = str(cm.mlir_module) + self.assertRegex( + asm, + "@.+\(%.+: !torch.tensor<\[9,8\],f32> {iree.abi.affinity = #hal.device.promise<@__device_1>}", + ) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)