[wip] add option to do activation/grad cast from hooks #170
base: main
Conversation
Force-pushed from 3fe1055 to 855795c
""" | ||
Hook to cast the incoming gradient to `torch.float8_e5m2` | ||
""" | ||
new_output = NoopFwToFloat8E5M2Bw.apply(output, module.emulate) |
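For context, a minimal sketch of how this line might sit inside a registered forward hook; the surrounding function, the registration call, and the import path are assumptions, not part of this diff:

```python
# Assumption: the import path may differ in this branch; NoopFwToFloat8E5M2Bw
# is the repo's autograd function that is a no-op in forward and casts the
# incoming gradient to torch.float8_e5m2 in backward.
from float8_experimental.float8_linear import NoopFwToFloat8E5M2Bw


def cast_grad_output_hook(module, args, output):
    """Hook to cast the incoming gradient to `torch.float8_e5m2`."""
    new_output = NoopFwToFloat8E5M2Bw.apply(output, module.emulate)
    return new_output


# Hypothetical registration on a Float8 module that carries an `emulate` flag:
# float8_module.register_forward_hook(cast_grad_output_hook)
```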
I'm a bit surprised that torch.compile does not support this case with an autograd function, as both the DTensor pre-forward hook and forward hook use autograd functions; it probably works there because DTensor uses something like allow_in_graph in dynamo. Maybe one workaround is to use a function that calls NoopFwToFloat8E5M2Bw.apply(output, module.emulate)
inside, and then mark that function as allow_in_graph in dynamo?
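A minimal sketch of that workaround, under the assumption that `torch._dynamo.allow_in_graph` behaves here the way it reportedly does for DTensor's hook helpers (untested on this branch):

```python
import torch

# NoopFwToFloat8E5M2Bw imported as in the sketch above (path is an assumption).
from float8_experimental.float8_linear import NoopFwToFloat8E5M2Bw


def _cast_grad_to_e5m2(output, emulate):
    # Plain function wrapper so dynamo sees a single callable rather than
    # tracing into the autograd.Function machinery directly.
    return NoopFwToFloat8E5M2Bw.apply(output, emulate)


# Ask dynamo to keep the wrapper as a single node in the captured graph.
torch._dynamo.allow_in_graph(_cast_grad_to_e5m2)


def cast_grad_output_hook(module, args, output):
    return _cast_grad_to_e5m2(output, module.emulate)
```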
note: the backward pre hook works with pytorch/pytorch#116454, but we still see the same dynamo error (and we are no longer using torch.autograd.Function)
Force-pushed from 270ce64 to f00cac9
@wanchaol, looks like the current dynamo issue is lack of support of […]. Would you expect […]?
Force-pushed from f00cac9 to f83bf21
Summary: This is a duplicate of #170. With more testing, ideally I think we wouldn't have the choice between hooks and modified forwards and would just use hooks. However, compile does not appear to support this yet.

Pull Request resolved: #198
Reviewed By: wanchaol
Differential Revision: D53287660
Pulled By: drisspg
fbshipit-source-id: 727e43e8850f3a480ba87df80c0710516ef45f28
Summary:
Testing moving the activation casting logic into hooks so that we can start building towards composability of Float8 with DTensor.
Note: needs pytorch/pytorch#116454 to land to enable the backward pre hook in all cases.
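A hedged sketch of the general shape being tested here, using plain round-trip casts as stand-ins for the repo's Float8Tensor machinery (all helper names below are illustrative, not this PR's API):

```python
import torch
import torch.nn as nn


def _emulate_e4m3(t):
    # Stand-in for the repo's activation cast; round-trip through float8
    # to emulate the precision loss.
    return t.to(torch.float8_e4m3fn).to(t.dtype)


def _emulate_e5m2(t):
    # Stand-in for the repo's gradient cast.
    return t.to(torch.float8_e5m2).to(t.dtype)


def activation_cast_pre_hook(module, args):
    # Forward pre-hook: cast incoming activations before the forward runs.
    return tuple(_emulate_e4m3(a) for a in args)


def grad_cast_backward_pre_hook(module, grad_output):
    # Full backward pre-hook: cast incoming gradients before the module's
    # backward runs; this is the path that needs pytorch/pytorch#116454.
    return tuple(_emulate_e5m2(g) for g in grad_output)


m = nn.Linear(16, 16)
m.register_forward_pre_hook(activation_cast_pre_hook)
m.register_full_backward_pre_hook(grad_cast_backward_pre_hook)

x = torch.randn(4, 16)
m(x).sum().backward()  # both hooks fire; grads land on m.weight / m.bias
```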