Add option to use tensor hooks for Dynamic Linear #198
Conversation
@@ -14,3 +14,8 @@
# this doesn't work with autocast + torch.compile + FSDP. Enabling this
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True

# If True, dynamic linear uses hooks for activation casting
need to figure out if we want this
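For context, here is a minimal sketch of how a flag like this could gate hook registration versus in-forward casting. The flag name, the ToyDynamicLinear class, the hook name, and the float8 round-trip are illustrative assumptions, not the repo's actual API:

import torch
import torch.nn as nn

# If True, dynamic linear uses hooks for activation casting (illustrative flag)
use_activation_hooks = True


class ToyDynamicLinear(nn.Linear):
    def cast_to_float8_e4m3fn(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real float8 cast; the real code would return a Float8Tensor.
        return x.to(torch.float8_e4m3fn).to(x.dtype)


def cast_x_to_float8_e4m3fn_pre_hook(module, args):
    # Forward pre-hook: the returned value replaces the module's positional args.
    return module.cast_to_float8_e4m3fn(args[0])


def make_linear(in_features: int, out_features: int) -> nn.Module:
    m = ToyDynamicLinear(in_features, out_features)
    if use_activation_hooks:
        m.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
    # Otherwise the cast would live inside a modified forward() instead.
    return m


lin = make_linear(16, 16)
lin(torch.randn(4, 16)).sum().backward()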
@pytest.mark.parametrize(
    "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
There was some testing that was globbed together before; this splits the test into two.
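Roughly, the split parametrization could be exercised like this; only the decorators come from the diff above, and the test body is a hypothetical placeholder:

import pytest
import torch


@pytest.mark.parametrize(
    "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
def test_dynamic_linear(linear_dtype, use_activation_hooks):
    # Hypothetical body: build the dynamic linear in the requested dtype,
    # with or without activation hooks, and compare against a reference.
    x = torch.randn(4, 16, dtype=linear_dtype)
    assert x.dtype == linear_dtype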
Looks great!! I have only one comment: I think we can avoid the backward pre-hook by just using a forward hook plus the existing autograd function.
    return module.cast_to_float8_e4m3fn(args[0])


def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
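For readers unfamiliar with the signature, here is a hedged sketch of what a full backward pre-hook along these lines could look like; the float8 round-trip is a stand-in for the real Float8Tensor cast:

import torch
import torch.nn as nn


def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
    # Full backward pre-hook: runs before the module's backward is computed;
    # the returned tuple replaces grad_output. Round-tripping through
    # float8_e5m2 stands in for the real cast, which keeps the tensor in float8.
    return tuple(g.to(torch.float8_e5m2).to(g.dtype) for g in grad_output)


lin = nn.Linear(8, 8)
lin.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
lin(torch.randn(2, 8)).sum().backward()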
Testing my understanding: it looks to me like the torch.compile issue comes from the interaction between the backward pre-hook and the tensor subclass. Could this be solved by using module.register_forward_hook and then, inside the forward hook, calling y = self.cast_to_float8_e5m2_bw(y), just like the current casting? I believe this would solve the backward pre-hook issue, but I'm not sure whether it would hit subclass issues (hopefully not).
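If I'm reading the suggestion right, it would look roughly like the sketch below. ToyCastDLDYToE5M2 and the hook name are toy stand-ins for the existing autograd function (identity forward, gradient cast in backward), so the names and exact behavior are assumptions:

import torch


class ToyCastDLDYToE5M2(torch.autograd.Function):
    # Identity in forward; casts the incoming gradient to float8_e5m2 in
    # backward (round-tripped back so this toy example stays in high precision).
    @staticmethod
    def forward(ctx, y):
        return y

    @staticmethod
    def backward(ctx, dldy):
        return dldy.to(torch.float8_e5m2).to(dldy.dtype)


def cast_y_to_float8_e5m2_forward_hook(module, args, output):
    # Forward hook: the returned value replaces the module's output, so the
    # autograd function lands on the graph without any backward pre-hook.
    return ToyCastDLDYToE5M2.apply(output)


lin = torch.nn.Linear(8, 8)
lin.register_forward_hook(cast_y_to_float8_e5m2_forward_hook)
lin(torch.randn(2, 8)).sum().backward()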
Ahhh, great idea; unfortunately it is still erroring for compile.
I think that is a little cleaner; is that enough for a good interaction with DTensor?
I have this locally and can push it up, but I don't know which one ultimately gets us closer: forward_hook or full_backward_pre_hook.
Yeah, I think a forward_hook might be a little cleaner and easier for DTensor to interact with. We would need to actually try composing those hooks to see if there's any gap. Let's try to land the forward_hook approach in this PR; if we find we need full_backward_pre_hook instead of forward_hook later, we can always change it back if needed.
sounds good!
stamp to unblock, thanks for getting this to work!
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
This is a duplicate of: #170
With more testing, ideally I think we wouldn't have the choice between hooks and modified forwards and would just use hooks. However, compile does not appear to support this yet.