Broadcast aten.maximum.default and aten.minimum.default inputs #586
base: main
Conversation
Force-pushed from dec0d74 to a6822d3.
```python
if len(args) > 1:
    other_tensor = args[1]
else:
    other_tensor = kwargs["other"]
```
Suggested change:
```python
other_tensor = None  # Explicitly initialize to a default value.
if len(args) > 1:
    other_tensor = args[1]
else:
    other_tensor = kwargs["other"]
```
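This arg-or-kwarg lookup is a common pattern when handling traced FX calls, since a second operand may arrive positionally or by keyword. A minimal torch-free sketch (the helper name `get_other_operand` is hypothetical, and it folds in the reviewer's suggestion of a `None` default):

```python
def get_other_operand(args, kwargs):
    # aten.maximum/minimum take (input, other); in a traced call the
    # "other" operand may be positional or a keyword argument.
    if len(args) > 1:
        return args[1]
    # Fall back to None instead of raising KeyError when "other" is absent.
    return kwargs.get("other")
```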
```python
if new_shape is not None or new_dtype is not None:
    shape = new_shape if new_shape is not None else new_node.meta["val"].size()
    dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
    fake_mode = FakeTensorMode()
    fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
    new_node.meta["val"] = fake_tensor
```
can you clarify the need for this?
call_function creates a new_node and assigns it the meta from the current_node being traversed, but new_node's shape and dtype may not be the same as cur_node's (for example, when new_node.target is aten.expand, the shape changes from current_node's), so this gives the user the option to specify the correct shape and dtype.
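The meta-override logic can be sketched without torch; the `Node` dataclass and `call_function_with_meta` helper below are hypothetical stand-ins for the FX node and the PR's call_function, showing only the inherit-then-override behavior described above:

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    # Minimal stand-in for an FX node: just a target and a meta dict.
    target: str
    meta: dict = field(default_factory=dict)

def call_function_with_meta(cur_node, target, new_shape=None, new_dtype=None):
    # The new node inherits cur_node's meta; the caller-supplied shape
    # and dtype then override the stale values (e.g. after aten.expand
    # changes the shape relative to cur_node).
    new_node = Node(target=target, meta=dict(cur_node.meta))
    if new_shape is not None:
        new_node.meta["shape"] = tuple(new_shape)
    if new_dtype is not None:
        new_node.meta["dtype"] = new_dtype
    return new_node
```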
```python
if input_tensor_shape == torch.Size([]):
    input_tensor_shape = torch.Size([1])
```
can you clarify the need for this?
The code below cannot handle the shape [], and the result of expanding [] is the same as expanding [1], so I treat [] as [1].
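The equivalence the author relies on can be stated in a torch-free sketch (the helper name `normalize_scalar_shape` is hypothetical): a 0-d shape () broadcasts exactly like (1,), since its single element fills every dimension of the expand target, so the normalization loses nothing.

```python
def normalize_scalar_shape(shape):
    # Treat a 0-d (scalar) shape as a 1-element 1-d shape; both expand
    # to the same result under broadcasting rules.
    return (1,) if len(shape) == 0 else tuple(shape)
```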
I hesitate to move forward with this. For this one, I feel it is better to wait for the proper broadcasting fix in TT-NN. The change looks fairly intrusive to me.
OK, then I will cancel this PR and just wait for tt-metal support: tenstorrent/tt-metal#12852
Pull request was converted to draft
Ticket
#592
Problem description
aten.maximum has a broadcasting issue (tenstorrent/tt-metal#12852). As a workaround, I use aten.expand to broadcast its inputs beforehand; the aten.expand may be lowered to ttnn later.
What's changed
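The expand-based workaround depends on computing the broadcast target shape for the two inputs. A minimal sketch of the standard NumPy/PyTorch broadcasting rule (the function name `broadcast_shape` is hypothetical, not from the PR):

```python
import itertools

def broadcast_shape(a, b):
    # Right-align the two shapes; each aligned dimension pair must be
    # equal, or one of them must be 1 (missing dims count as 1).
    out = []
    for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(max(x, y))
    return tuple(reversed(out))
```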