-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Broadcast aten.maximum.default
and aten.minimum.default
inputs
#586
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,6 +1,7 @@ | ||||||||||||||||||||
import torch | ||||||||||||||||||||
import ttnn | ||||||||||||||||||||
import math | ||||||||||||||||||||
import numpy as np | ||||||||||||||||||||
from torch._subclasses.fake_tensor import unset_fake_temporarily | ||||||||||||||||||||
from torch_ttnn.utils import ( | ||||||||||||||||||||
GraphCleanup, | ||||||||||||||||||||
|
@@ -17,6 +18,7 @@ | |||||||||||||||||||
|
||||||||||||||||||||
from torch.fx.passes.infra.pass_base import PassBase, PassResult | ||||||||||||||||||||
import torch.fx.traceback as fx_traceback | ||||||||||||||||||||
from torch._subclasses.fake_tensor import FakeTensorMode | ||||||||||||||||||||
from . import target_wrappers | ||||||||||||||||||||
from .to_tt_guard import can_lowering_to_ttnn | ||||||||||||||||||||
|
||||||||||||||||||||
|
@@ -444,13 +446,19 @@ def __init__(self, node): | |||||||||||||||||||
self.g = node.graph | ||||||||||||||||||||
self.node = node | ||||||||||||||||||||
|
||||||||||||||||||||
def call_function(self, target, args=(), kwargs={}): | ||||||||||||||||||||
def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None): | ||||||||||||||||||||
new_node = self.g.call_function(target, args, kwargs) | ||||||||||||||||||||
new_node.meta = self.node.meta | ||||||||||||||||||||
new_node.meta = self.node.meta.copy() | ||||||||||||||||||||
if hasattr(self.node.target, "_schema"): | ||||||||||||||||||||
new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node) | ||||||||||||||||||||
if target == ttnn.layer_norm: | ||||||||||||||||||||
new_node.meta["val"] = new_node.meta["val"][0] | ||||||||||||||||||||
if new_shape is not None or new_dtype is not None: | ||||||||||||||||||||
shape = new_shape if new_shape is not None else new_node.meta["val"].size() | ||||||||||||||||||||
dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype | ||||||||||||||||||||
fake_mode = FakeTensorMode() | ||||||||||||||||||||
fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype)) | ||||||||||||||||||||
new_node.meta["val"] = fake_tensor | ||||||||||||||||||||
return new_node | ||||||||||||||||||||
|
||||||||||||||||||||
def inserting_before(self, node): | ||||||||||||||||||||
|
@@ -612,6 +620,8 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps): | |||||||||||||||||||
if not (hasattr(node, "meta") and "val" in node.meta and hasattr(node.meta["val"], "size")): | ||||||||||||||||||||
return None | ||||||||||||||||||||
input_tensor_shape = args[0].meta["val"].size() | ||||||||||||||||||||
if input_tensor_shape == torch.Size([]): | ||||||||||||||||||||
input_tensor_shape = torch.Size([1]) | ||||||||||||||||||||
Comment on lines
+623
to
+624
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you clarify the need for this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. below code cannot handle |
||||||||||||||||||||
output_shape = node.meta["val"].size() | ||||||||||||||||||||
if input_tensor_shape.numel() == output_shape.numel(): | ||||||||||||||||||||
if input_tensor_shape != output_shape: | ||||||||||||||||||||
|
@@ -1131,12 +1141,65 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps): | |||||||||||||||||||
return gm | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
def broadcast_tensors(g, tensors): | ||||||||||||||||||||
tensors_shapes = [get_shape(tensors[i]) for i in range(len(tensors))] | ||||||||||||||||||||
broadcasted_shape = torch.Size(np.broadcast_shapes(*tensors_shapes)) | ||||||||||||||||||||
broadcasted_tensors = [] | ||||||||||||||||||||
for i in range(len(tensors)): | ||||||||||||||||||||
if tensors_shapes[i] == broadcasted_shape: | ||||||||||||||||||||
broadcasted_tensors.append(tensors[i]) | ||||||||||||||||||||
else: | ||||||||||||||||||||
broadcasted_tensors.append( | ||||||||||||||||||||
g.call_function( | ||||||||||||||||||||
torch.ops.aten.expand.default, | ||||||||||||||||||||
(tensors[i], broadcasted_shape), | ||||||||||||||||||||
new_shape=broadcasted_shape, | ||||||||||||||||||||
new_dtype=tensors[i].meta["val"].dtype, | ||||||||||||||||||||
) | ||||||||||||||||||||
) | ||||||||||||||||||||
return broadcasted_shape, broadcasted_tensors | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
def DigestAtenOps(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||||||||||||||||||||
nodes = list(gm.graph.nodes) | ||||||||||||||||||||
for node in nodes: | ||||||||||||||||||||
g = GraphWrapper(node) | ||||||||||||||||||||
|
||||||||||||||||||||
def rewrite_node(node): | ||||||||||||||||||||
args = node.args | ||||||||||||||||||||
kwargs = node.kwargs | ||||||||||||||||||||
|
||||||||||||||||||||
# workaround for issue #64 | ||||||||||||||||||||
if node.target in [torch.ops.aten.maximum.default, torch.ops.aten.minimum.default]: | ||||||||||||||||||||
self_tensor = args[0] | ||||||||||||||||||||
if len(args) > 1: | ||||||||||||||||||||
other_tensor = args[1] | ||||||||||||||||||||
else: | ||||||||||||||||||||
other_tensor = kwargs["other"] | ||||||||||||||||||||
Comment on lines
+1175
to
+1178
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||
if get_shape(self_tensor) is None or get_shape(other_tensor) is None: | ||||||||||||||||||||
return None | ||||||||||||||||||||
broadcasted_shape, broadcasted_tensors = broadcast_tensors(g, [self_tensor, other_tensor]) | ||||||||||||||||||||
return g.call_function(node.target, tuple(broadcasted_tensors)) | ||||||||||||||||||||
|
||||||||||||||||||||
with g.inserting_before(node): | ||||||||||||||||||||
new_node = rewrite_node(node) | ||||||||||||||||||||
if new_node is not None: | ||||||||||||||||||||
node.replace_all_uses_with( | ||||||||||||||||||||
new_node, | ||||||||||||||||||||
delete_user_cb=lambda node: node != new_node, | ||||||||||||||||||||
) | ||||||||||||||||||||
|
||||||||||||||||||||
gm = GraphCleanup(gm) | ||||||||||||||||||||
return gm | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
class ToTtPass(PassBase): | ||||||||||||||||||||
def __init__(self, device, use_less_ttnn_op_types): | ||||||||||||||||||||
self.device = device | ||||||||||||||||||||
self.use_less_ttnn_op_types = use_less_ttnn_op_types | ||||||||||||||||||||
|
||||||||||||||||||||
def call(self, gm: torch.fx.GraphModule): | ||||||||||||||||||||
gm = DigestAtenOps(gm) | ||||||||||||||||||||
# Replace more patterns with torch.fx.Transformer | ||||||||||||||||||||
gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform() | ||||||||||||||||||||
|
||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you clarify the need for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
call_function is to create a new_node and is assigned meta from current_node which is being traversed, but new_node's shape & dtype may not same with cur_node (for example,
new_node.target
isaten.expand
from current_node and then shape change), so there give the option for user to specify the correct shape & dtype