Broadcast aten.maximum.default and aten.minimum.default inputs #586

Draft. Wants to merge 2 commits into base: main.
16 changes: 4 additions & 12 deletions tests/lowering/eltwise/binary/test_maximum.py
@@ -16,18 +16,10 @@ def forward(self, x, y):
     "input_shapes",
     (
         ((32, 32), (32, 32)),
-        pytest.param(
-            ((64,), (32, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 32), (64, 1)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 1), (1, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
+        ((64,), (32, 64)),
+        ((64, 32), (64, 1)),
+        ((64, 1), (1, 64)),
+        ((1, 16, 59, 59), ()),
     ),
 )
 def test_maximum(device, input_shapes):
16 changes: 4 additions & 12 deletions tests/lowering/eltwise/binary/test_minimum.py
@@ -16,18 +16,10 @@ def forward(self, x, y):
     "input_shapes",
     (
         ((32, 32), (32, 32)),
-        pytest.param(
-            ((64,), (32, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 32), (64, 1)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 1), (1, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
+        ((64,), (32, 64)),
+        ((64, 32), (64, 1)),
+        ((64, 1), (1, 64)),
+        ((1, 16, 59, 59), ()),
     ),
 )
 def test_minimum(device, input_shapes):
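For reference, the un-xfailed parametrizations above exercise standard PyTorch broadcasting; a standalone sketch of the semantics (not part of this PR):

    import torch

    x = torch.rand(64)
    y = torch.rand(32, 64)
    out = torch.minimum(x, y)           # (64,) broadcasts against (32, 64) -> (32, 64)
    assert out.shape == torch.Size([32, 64])

    z = torch.maximum(torch.rand(64, 1), torch.rand(1, 64))
    assert z.shape == torch.Size([64, 64])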
3 changes: 2 additions & 1 deletion torch_ttnn/passes/lowering/to_tt_guard.py
@@ -182,7 +182,7 @@
 # ttnn.from_torch not support scalar
 # RuntimeError: TT_FATAL @ tensor/types.cpp:209: normalized_index >= 0 and normalized_index < rank
 # not lowering ttnn.maximum to avoid ttnn.from_torch of scalar
-aten_maximum_default_blocklist += [["Tensor<[1, 16, 59, 59]> self = ?", "Tensor other = ?"]]
+aten_maximum_default_blocklist = [["Tensor<[1, 16, 59, 59]> self = ?", "Tensor other = ?"]]

 # torch._dynamo.exc.BackendCompilerFailed: backend='ttnn_backend' raised:
 # RuntimeError: aten::clone() Expected a value of type 'Tensor' for argument 'self' but instead found type 'SymInt'.
@@ -318,6 +318,7 @@
 GUARD[torch.ops.aten.gt.Scalar] = partial(guard_aten, aten_gt_Scalar_blocklist)
 GUARD[torch.ops.aten.unsqueeze.default] = partial(guard_aten, aten_unsqueeze_default_blocklist)
 GUARD[torch.ops.aten.cumsum.default] = partial(guard_aten, aten_cumsum_default_blocklist)
+GUARD[torch.ops.aten.maximum.default] = partial(guard_aten, aten_maximum_default_blocklist)


 def can_lowering_to_ttnn(node):
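For context, a sketch of the failure this blocklist entry works around, reconstructed from the comment in the hunk above (assumed repro; requires a working ttnn install):

    import torch
    import ttnn

    scalar = torch.tensor(2.0)  # 0-d tensor, shape torch.Size([])
    # Per the comment above, converting a 0-d tensor is expected to raise:
    # ttnn.from_torch(scalar)   # RuntimeError: TT_FATAL ... normalized_index >= 0 and normalized_index < rank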
6 changes: 0 additions & 6 deletions torch_ttnn/passes/lowering/to_tt_guard_autogen.py
@@ -56,11 +56,6 @@
     ["Tensor<[16, 1, 1]> self = ?", "Optional[number] min = ?", "Optional[number] max = 4.605170185988092"],
     ["Tensor<[32, 1, 1]> self = ?", "Optional[number] min = ?", "Optional[number] max = 4.605170185988092"],
 ]
-aten_maximum_default_blocklist = [
-    ["Tensor<[1, 16, 19, 19]> self = ?", "Tensor other = ?"],
-    ["Tensor<[1, 16, 59, 59]> self = ?", "Tensor<[]> other = ?"],
-    ["Tensor<[1, 16, 1, 60]> self = ?", "Tensor<[]> other = ?"],
-]
 aten__log_softmax_default_blocklist = [["Tensor<[19, 256008]> self = ?", "int dim = 1", "bool half_to_float = False"]]
 aten_full_default_blocklist = [
     [
@@ -1498,7 +1493,6 @@ def guard_aten(blocklist, node):

 GUARD = {
     torch.ops.aten.clamp.default: partial(guard_aten, aten_clamp_default_blocklist),
-    torch.ops.aten.maximum.default: partial(guard_aten, aten_maximum_default_blocklist),
     torch.ops.aten._log_softmax.default: partial(guard_aten, aten__log_softmax_default_blocklist),
     torch.ops.aten.full.default: partial(guard_aten, aten_full_default_blocklist),
     torch.ops.aten._scaled_dot_product_flash_attention.default: partial(
67 changes: 65 additions & 2 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1,6 +1,7 @@
 import torch
 import ttnn
 import math
+import numpy as np
 from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch_ttnn.utils import (
     GraphCleanup,
@@ -17,6 +18,7 @@

 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 import torch.fx.traceback as fx_traceback
+from torch._subclasses.fake_tensor import FakeTensorMode
 from . import target_wrappers
 from .to_tt_guard import can_lowering_to_ttnn

@@ -444,13 +446,19 @@ def __init__(self, node):
         self.g = node.graph
         self.node = node

-    def call_function(self, target, args=(), kwargs={}):
+    def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None):
         new_node = self.g.call_function(target, args, kwargs)
-        new_node.meta = self.node.meta
+        new_node.meta = self.node.meta.copy()
         if hasattr(self.node.target, "_schema"):
             new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
         if target == ttnn.layer_norm:
             new_node.meta["val"] = new_node.meta["val"][0]
+        if new_shape is not None or new_dtype is not None:
+            shape = new_shape if new_shape is not None else new_node.meta["val"].size()
+            dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
+            fake_mode = FakeTensorMode()
+            fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
+            new_node.meta["val"] = fake_tensor
Comment on lines +456 to +461

[Member] can you clarify the need for this?

[Collaborator, Author] call_function creates a new_node and assigns it the meta of the node currently being traversed, but new_node's shape and dtype may not match the current node's (for example, new_node.target may be aten.expand derived from the current node, which changes the shape), so this gives the caller the option to specify the correct shape and dtype.
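To illustrate the mechanism discussed in this thread, a minimal sketch (the shape and dtype here are illustrative assumptions) of how FakeTensorMode.from_tensor builds the replacement meta["val"]:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # A FakeTensor carries shape/dtype metadata without real storage,
    # which is all node.meta["val"] needs.
    fake_mode = FakeTensorMode()
    fake_tensor = fake_mode.from_tensor(torch.zeros((64, 32), dtype=torch.bfloat16))
    print(fake_tensor.shape, fake_tensor.dtype)  # torch.Size([64, 32]) torch.bfloat16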

         return new_node

     def inserting_before(self, node):
@@ -612,6 +620,8 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
         if not (hasattr(node, "meta") and "val" in node.meta and hasattr(node.meta["val"], "size")):
             return None
         input_tensor_shape = args[0].meta["val"].size()
+        if input_tensor_shape == torch.Size([]):
+            input_tensor_shape = torch.Size([1])
Comment on lines +623 to +624

[Member] can you clarify the need for this?

[Collaborator, Author] the code below cannot handle [], and expanding [] yields the same result as expanding [1], so [] is treated as [1].
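To illustrate the equivalence claimed in the reply, a standalone sketch (not part of the PR):

    import torch

    s = torch.tensor(3.14)            # 0-d tensor, shape torch.Size([])
    a = s.expand(2, 3)                # expanding shape [] ...
    b = s.reshape(1).expand(2, 3)     # ... gives the same result as expanding shape [1]
    assert torch.equal(a, b)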

         output_shape = node.meta["val"].size()
         if input_tensor_shape.numel() == output_shape.numel():
             if input_tensor_shape != output_shape:
@@ -1131,12 +1141,65 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
     return gm


+def broadcast_tensors(g, tensors):
+    tensors_shapes = [get_shape(tensors[i]) for i in range(len(tensors))]
+    broadcasted_shape = torch.Size(np.broadcast_shapes(*tensors_shapes))
+    broadcasted_tensors = []
+    for i in range(len(tensors)):
+        if tensors_shapes[i] == broadcasted_shape:
+            broadcasted_tensors.append(tensors[i])
+        else:
+            broadcasted_tensors.append(
+                g.call_function(
+                    torch.ops.aten.expand.default,
+                    (tensors[i], broadcasted_shape),
+                    new_shape=broadcasted_shape,
+                    new_dtype=tensors[i].meta["val"].dtype,
+                )
+            )
+    return broadcasted_shape, broadcasted_tensors
+
+
+def DigestAtenOps(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    nodes = list(gm.graph.nodes)
+    for node in nodes:
+        g = GraphWrapper(node)
+
+        def rewrite_node(node):
+            args = node.args
+            kwargs = node.kwargs
+
+            # workaround for issue #64
+            if node.target in [torch.ops.aten.maximum.default, torch.ops.aten.minimum.default]:
+                self_tensor = args[0]
+                if len(args) > 1:
+                    other_tensor = args[1]
+                else:
+                    other_tensor = kwargs["other"]
Comment on lines +1175 to +1178

[Member] Suggested change:

-                if len(args) > 1:
-                    other_tensor = args[1]
-                else:
-                    other_tensor = kwargs["other"]
+                other_tensor = None  # Explicitly initialize to a default value.
+                if len(args) > 1:
+                    other_tensor = args[1]
+                else:
+                    other_tensor = kwargs["other"]

+                if get_shape(self_tensor) is None or get_shape(other_tensor) is None:
+                    return None
+                broadcasted_shape, broadcasted_tensors = broadcast_tensors(g, [self_tensor, other_tensor])
+                return g.call_function(node.target, tuple(broadcasted_tensors))

+        with g.inserting_before(node):
+            new_node = rewrite_node(node)
+            if new_node is not None:
+                node.replace_all_uses_with(
+                    new_node,
+                    delete_user_cb=lambda node: node != new_node,
+                )
+
+    gm = GraphCleanup(gm)
+    return gm


 class ToTtPass(PassBase):
     def __init__(self, device, use_less_ttnn_op_types):
         self.device = device
         self.use_less_ttnn_op_types = use_less_ttnn_op_types

     def call(self, gm: torch.fx.GraphModule):
+        gm = DigestAtenOps(gm)
         # Replace more patterns with torch.fx.Transformer
         gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform()
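As a usage note on broadcast_tensors above: np.broadcast_shapes applies NumPy's broadcasting rules, which match PyTorch's for these cases; a standalone sketch:

    import numpy as np

    print(np.broadcast_shapes((64,), (32, 64)))      # (32, 64)
    print(np.broadcast_shapes((64, 1), (1, 64)))     # (64, 64)
    print(np.broadcast_shapes((1, 16, 59, 59), ()))  # (1, 16, 59, 59)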