From cce8552d5e36c40fa8ff90af2d87d3b6df12e45e Mon Sep 17 00:00:00 2001
From: swimdi
Date: Wed, 11 Dec 2024 16:48:33 +0800
Subject: [PATCH 1/2] Broadcast maximum input, remove aten_maximum_default_blocklist
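
Lowering aten.maximum previously had to blocklist input variations that
need broadcasting (issue #64), because ttnn.maximum does not broadcast
and ttnn.from_torch rejects scalar tensors. This patch adds a
DigestAtenOps pass that runs before lowering and rewrites aten.maximum
so that both operands are first expanded (via aten.expand) to their
common broadcast shape, computed with np.broadcast_shapes.
GraphWrapper.call_function gains optional new_shape/new_dtype arguments
so the inserted expand node gets a correct fake-tensor "val" meta entry.
The previously xfailed broadcast cases in test_maximum.py are enabled,
the autogenerated aten_maximum_default_blocklist is removed, and only
the scalar-other variation stays blocked in the hand-written guard.

Illustrative only, not part of the patch: how the common shape is
computed and why the pre-expanded operands are equivalent.

    import numpy as np
    import torch

    # np.broadcast_shapes gives the shape broadcast_tensors() expands to.
    assert np.broadcast_shapes((64, 1), (1, 64)) == (64, 64)
    # A 0-d tensor (shape ()) broadcasts against anything; this is the
    # ((1, 16, 59, 59), ()) case added to the tests.
    assert np.broadcast_shapes((1, 16, 59, 59), ()) == (1, 16, 59, 59)
    # After expansion both operands have equal shapes, and the result
    # matches PyTorch's own broadcasting semantics.
    a, b = torch.rand(64, 1), torch.rand(1, 64)
    torch.testing.assert_close(
        torch.maximum(a.expand(64, 64), b.expand(64, 64)),
        torch.maximum(a, b),
    )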

---
 tests/lowering/eltwise/binary/test_maximum.py | 16 ++---
 torch_ttnn/passes/lowering/to_tt_guard.py     |  3 +-
 .../passes/lowering/to_tt_guard_autogen.py    |  6 --
 torch_ttnn/passes/lowering/to_tt_pass.py      | 67 ++++++++++++++++++-
 4 files changed, 71 insertions(+), 21 deletions(-)

diff --git a/tests/lowering/eltwise/binary/test_maximum.py b/tests/lowering/eltwise/binary/test_maximum.py
index cc037221e..eb3e7e1f6 100644
--- a/tests/lowering/eltwise/binary/test_maximum.py
+++ b/tests/lowering/eltwise/binary/test_maximum.py
@@ -16,18 +16,10 @@ def forward(self, x, y):
     "input_shapes",
     (
         ((32, 32), (32, 32)),
-        pytest.param(
-            ((64,), (32, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 32), (64, 1)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 1), (1, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
+        ((64,), (32, 64)),
+        ((64, 32), (64, 1)),
+        ((64, 1), (1, 64)),
+        ((1, 16, 59, 59), ()),
     ),
 )
 def test_maximum(device, input_shapes):
diff --git a/torch_ttnn/passes/lowering/to_tt_guard.py b/torch_ttnn/passes/lowering/to_tt_guard.py
index 1f7230eca..f8efd9397 100644
--- a/torch_ttnn/passes/lowering/to_tt_guard.py
+++ b/torch_ttnn/passes/lowering/to_tt_guard.py
@@ -182,7 +182,7 @@
 # ttnn.from_torch not support scalar
 # RuntimeError: TT_FATAL @ tensor/types.cpp:209: normalized_index >= 0 and normalized_index < rank
 # not lowering ttnn.maximum to avoid ttnn.from_torch of scalar
-aten_maximum_default_blocklist += [["Tensor<[1, 16, 59, 59]> self = ?", "Tensor other = ?"]]
+aten_maximum_default_blocklist = [["Tensor<[1, 16, 59, 59]> self = ?", "Tensor other = ?"]]
 
 # torch._dynamo.exc.BackendCompilerFailed: backend='ttnn_backend' raised:
 # RuntimeError: aten::clone() Expected a value of type 'Tensor' for argument 'self' but instead found type 'SymInt'.
@@ -318,6 +318,7 @@
 GUARD[torch.ops.aten.gt.Scalar] = partial(guard_aten, aten_gt_Scalar_blocklist)
 GUARD[torch.ops.aten.unsqueeze.default] = partial(guard_aten, aten_unsqueeze_default_blocklist)
 GUARD[torch.ops.aten.cumsum.default] = partial(guard_aten, aten_cumsum_default_blocklist)
+GUARD[torch.ops.aten.maximum.default] = partial(guard_aten, aten_maximum_default_blocklist)
 
 
 def can_lowering_to_ttnn(node):
diff --git a/torch_ttnn/passes/lowering/to_tt_guard_autogen.py b/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
index 725a34b06..83110763d 100644
--- a/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
+++ b/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
@@ -56,11 +56,6 @@
     ["Tensor<[16, 1, 1]> self = ?", "Optional[number] min = ?", "Optional[number] max = 4.605170185988092"],
     ["Tensor<[32, 1, 1]> self = ?", "Optional[number] min = ?", "Optional[number] max = 4.605170185988092"],
 ]
-aten_maximum_default_blocklist = [
-    ["Tensor<[1, 16, 19, 19]> self = ?", "Tensor other = ?"],
-    ["Tensor<[1, 16, 59, 59]> self = ?", "Tensor<[]> other = ?"],
-    ["Tensor<[1, 16, 1, 60]> self = ?", "Tensor<[]> other = ?"],
-]
 aten__log_softmax_default_blocklist = [["Tensor<[19, 256008]> self = ?", "int dim = 1", "bool half_to_float = False"]]
 aten_full_default_blocklist = [
     [
@@ -1498,7 +1493,6 @@ def guard_aten(blocklist, node):
 
 GUARD = {
     torch.ops.aten.clamp.default: partial(guard_aten, aten_clamp_default_blocklist),
-    torch.ops.aten.maximum.default: partial(guard_aten, aten_maximum_default_blocklist),
     torch.ops.aten._log_softmax.default: partial(guard_aten, aten__log_softmax_default_blocklist),
     torch.ops.aten.full.default: partial(guard_aten, aten_full_default_blocklist),
     torch.ops.aten._scaled_dot_product_flash_attention.default: partial(
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 54185a928..47c6e1f40 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1,6 +1,7 @@
 import torch
 import ttnn
 import math
+import numpy as np
 from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch_ttnn.utils import (
     GraphCleanup,
@@ -17,6 +18,7 @@
 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 import torch.fx.traceback as fx_traceback
+from torch._subclasses.fake_tensor import FakeTensorMode
 
 from . import target_wrappers
 from .to_tt_guard import can_lowering_to_ttnn
 
@@ -444,13 +446,19 @@ def __init__(self, node):
         self.g = node.graph
         self.node = node
 
-    def call_function(self, target, args=(), kwargs={}):
+    def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None):
         new_node = self.g.call_function(target, args, kwargs)
-        new_node.meta = self.node.meta
+        new_node.meta = self.node.meta.copy()
         if hasattr(self.node.target, "_schema"):
             new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
         if target == ttnn.layer_norm:
             new_node.meta["val"] = new_node.meta["val"][0]
+        if new_shape is not None or new_dtype is not None:
+            shape = new_shape if new_shape is not None else new_node.meta["val"].size()
+            dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
+            fake_mode = FakeTensorMode()
+            fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
+            new_node.meta["val"] = fake_tensor
         return new_node
 
     def inserting_before(self, node):
@@ -612,6 +620,8 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
         if not (hasattr(node, "meta") and "val" in node.meta and hasattr(node.meta["val"], "size")):
             return None
         input_tensor_shape = args[0].meta["val"].size()
+        if input_tensor_shape == torch.Size([]):
+            input_tensor_shape = torch.Size([1])
         output_shape = node.meta["val"].size()
         if input_tensor_shape.numel() == output_shape.numel():
             if input_tensor_shape != output_shape:
@@ -1131,12 +1141,65 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
     return gm
 
 
+def broadcast_tensors(g, tensors):
+    tensors_shapes = [get_shape(tensors[i]) for i in range(len(tensors))]
+    broadcasted_shape = torch.Size(np.broadcast_shapes(*tensors_shapes))
+    broadcasted_tensors = []
+    for i in range(len(tensors)):
+        if tensors_shapes[i] == broadcasted_shape:
+            broadcasted_tensors.append(tensors[i])
+        else:
+            broadcasted_tensors.append(
+                g.call_function(
+                    torch.ops.aten.expand.default,
+                    (tensors[i], broadcasted_shape),
+                    new_shape=broadcasted_shape,
+                    new_dtype=tensors[i].meta["val"].dtype,
+                )
+            )
+    return broadcasted_shape, broadcasted_tensors
+
+
+def DigestAtenOps(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    nodes = list(gm.graph.nodes)
+    for node in nodes:
+        g = GraphWrapper(node)
+
+        def rewrite_node(node):
+            args = node.args
+            kwargs = node.kwargs
+
+            # workaround for issue #64
+            if node.target == torch.ops.aten.maximum.default:
+                self_tensor = args[0]
+                if len(args) > 1:
+                    other_tensor = args[1]
+                else:
+                    other_tensor = kwargs["other"]
+                if get_shape(self_tensor) is None or get_shape(other_tensor) is None:
+                    return None
+                broadcasted_shape, broadcasted_tensors = broadcast_tensors(g, [self_tensor, other_tensor])
+                return g.call_function(torch.ops.aten.maximum.default, tuple(broadcasted_tensors))
+
+        with g.inserting_before(node):
+            new_node = rewrite_node(node)
+            if new_node is not None:
+                node.replace_all_uses_with(
+                    new_node,
+                    delete_user_cb=lambda node: node != new_node,
+                )
+
+    gm = GraphCleanup(gm)
+    return gm
+
+
 class ToTtPass(PassBase):
     def __init__(self, device, use_less_ttnn_op_types):
         self.device = device
         self.use_less_ttnn_op_types = use_less_ttnn_op_types
 
     def call(self, gm: torch.fx.GraphModule):
+        gm = DigestAtenOps(gm)
         # Replace more patterns with torch.fx.Transformer
         gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform()
 

From a6822d376d098d4ecc7a1e8dce8d1b1828efc798 Mon Sep 17 00:00:00 2001
From: swimdi
Date: Wed, 11 Dec 2024 18:05:24 +0800
Subject: [PATCH 2/2] Broadcast minimum input
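
Extend the issue #64 broadcast workaround from aten.maximum to
aten.minimum: rewrite_node now matches on node.target for both ops and
re-emits the original target with the pre-broadcast operands, and the
previously xfailed shape combinations in test_minimum.py are enabled.

Illustrative only, not part of the patch: both ops can share one
rewrite path because expanding to the common broadcast shape does not
change the result of either op.

    import torch

    a, b = torch.rand(64, 1), torch.rand(1, 64)
    for op in (torch.ops.aten.maximum.default, torch.ops.aten.minimum.default):
        torch.testing.assert_close(
            op(a.expand(64, 64), b.expand(64, 64)),
            op(a, b),
        )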

---
 tests/lowering/eltwise/binary/test_minimum.py | 16 ++++------------
 torch_ttnn/passes/lowering/to_tt_pass.py      |  4 ++--
 2 files changed, 6 insertions(+), 14 deletions(-)

diff --git a/tests/lowering/eltwise/binary/test_minimum.py b/tests/lowering/eltwise/binary/test_minimum.py
index e577e46a0..6174668f4 100644
--- a/tests/lowering/eltwise/binary/test_minimum.py
+++ b/tests/lowering/eltwise/binary/test_minimum.py
@@ -16,18 +16,10 @@ def forward(self, x, y):
     "input_shapes",
     (
         ((32, 32), (32, 32)),
-        pytest.param(
-            ((64,), (32, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 32), (64, 1)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 1), (1, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
+        ((64,), (32, 64)),
+        ((64, 32), (64, 1)),
+        ((64, 1), (1, 64)),
+        ((1, 16, 59, 59), ()),
     ),
 )
 def test_minimum(device, input_shapes):
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 47c6e1f40..41b5c13fa 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1170,7 +1170,7 @@ def rewrite_node(node):
             kwargs = node.kwargs
 
             # workaround for issue #64
-            if node.target == torch.ops.aten.maximum.default:
+            if node.target in [torch.ops.aten.maximum.default, torch.ops.aten.minimum.default]:
                 self_tensor = args[0]
                 if len(args) > 1:
                     other_tensor = args[1]
@@ -1179,7 +1179,7 @@ def rewrite_node(node):
                 if get_shape(self_tensor) is None or get_shape(other_tensor) is None:
                     return None
                 broadcasted_shape, broadcasted_tensors = broadcast_tensors(g, [self_tensor, other_tensor])
-                return g.call_function(torch.ops.aten.maximum.default, tuple(broadcasted_tensors))
+                return g.call_function(node.target, tuple(broadcasted_tensors))
 
         with g.inserting_before(node):
             new_node = rewrite_node(node)