This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add option to use tensor hooks for Dynamic Linear #198

Closed · 3 commits

Conversation

drisspg (Contributor) commented Jan 30, 2024

Summary

This is a duplicate of #170.
With more testing, ideally we wouldn't have a choice between hooks and a modified forward and would just use hooks. However, torch.compile does not appear to support this yet.

@facebook-github-bot added the CLA Signed label on Jan 30, 2024.
@@ -14,3 +14,8 @@
# this doesn't work with autocast + torch.compile + FSDP. Enabling this
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True

# If True, dynamic linear uses hooks for activation casting
drisspg (Contributor, Author) commented on this hunk:

Need to figure out if we want this.
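
To make the option concrete, here is a minimal, hypothetical sketch of how a flag like the one in the diff above could switch between hook-based casting and a modified forward. The flag name, `DynamicLinearSketch`, and `cast_input` are all illustrative assumptions, not the PR's actual code:

```python
import torch
import torch.nn as nn

# Hypothetical config flag corresponding to the comment in the diff above.
use_activation_hooks = False


class DynamicLinearSketch(nn.Linear):
    """Toy stand-in for a dynamic float8 linear layer (illustration only)."""

    def cast_input(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder for the e4m3 activation cast.
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not use_activation_hooks:
            # "Modified forward" path: cast inline before the matmul.
            x = self.cast_input(x)
        return super().forward(x)


def cast_input_forward_pre_hook(module, args):
    # Hook path: cast the first positional input before forward runs.
    return (module.cast_input(args[0]),) + args[1:]


linear = DynamicLinearSketch(8, 8)
if use_activation_hooks:
    linear.register_forward_pre_hook(cast_input_forward_pre_hook)
```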

@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
drisspg (Contributor, Author) commented on this hunk:

There was some testing that was globbed together before; this splits the test into two.
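
For context, a minimal sketch of how the parametrization above could drive one of the split-out tests; the body below is a placeholder, not the PR's actual test code:

```python
import pytest
import torch


@pytest.mark.parametrize(
    "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
def test_dynamic_linear_sketch(linear_dtype, use_activation_hooks):
    # Placeholder body: a real test would build the dynamic linear with
    # `use_activation_hooks` and compare numerics against an nn.Linear
    # reference in `linear_dtype`.
    x = torch.randn(4, 8, dtype=linear_dtype)
    assert x.dtype == linear_dtype
```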

drisspg (Contributor, Author) commented Jan 30, 2024

@bdhirsh, @wanchaol This is a dupe of #170 plus some additions. Module hooks indeed fail with compile today for the same reason: multiple fake modes in the environment.

drisspg mentioned this pull request on Jan 30, 2024.
wanchaol (Contributor) left a comment:

Looks great!! I have only one comment: I think we can avoid the backward pre-hook by just using a forward hook plus the existing autograd function.

return module.cast_to_float8_e4m3fn(args[0])


def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
wanchaol (Contributor) commented on this hunk:
Testing my understanding: it looks to me like the torch.compile issue is related to the interaction between the backward pre-hook and tensor subclasses.

But could this be solved by using module.register_forward_hook and then, inside the forward hook, calling y = self.cast_to_float8_e5m2_bw(y), just like the current casting?

This would solve the backward pre-hook issue, I believe, but I'm not sure whether it would hit subclass issues (hopefully not).
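
A minimal sketch of the pattern being suggested here, using a stand-in autograd function; apart from cast_to_float8_e5m2_bw being quoted from the comment above, all names are illustrative assumptions:

```python
import torch


class NoopToFloat8E5M2BwSketch(torch.autograd.Function):
    """Stand-in for the existing autograd function: identity in forward;
    the real version would cast grad_output to float8 (e5m2) in backward."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, g):
        # The real implementation would perform the e5m2 cast on `g` here.
        return g


def cast_output_forward_hook(module, args, output):
    # Runs after forward; wrapping the output defers the cast to backward,
    # avoiding a full_backward_pre_hook entirely.
    return NoopToFloat8E5M2BwSketch.apply(output)


linear = torch.nn.Linear(8, 8)
linear.register_forward_hook(cast_output_forward_hook)
y = linear(torch.randn(2, 8, requires_grad=True))
y.sum().backward()
```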

drisspg (Contributor, Author) commented Jan 31, 2024:

Ahhh, great idea; unfortunately it still errors under compile.

I think that version is a little cleaner. Is that enough for a good interaction with DTensor?

I have this locally and can push it up, but I don't know which one ultimately gets us closer: forward_hook or full_backward_pre_hook.

wanchaol (Contributor) replied:
Yeah, I think a forward_hook might be a little cleaner and easier for DTensor to interact with. We would need to actually try composing those hooks to see if there's any gap. Let's try to land the forward_hook approach in this PR; if we later find we need full_backward_pre_hook instead of forward_hook, we can always change it back.
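
For contrast, a sketch of the full_backward_pre_hook alternative being weighed here; the cast itself is stubbed out and the hook name is illustrative:

```python
import torch


def cast_grad_output_backward_pre_hook(module, grad_output):
    # grad_output is a tuple of gradients w.r.t. the module's outputs; a real
    # implementation would cast each one to float8 (e5m2) instead of passing
    # it through unchanged.
    return tuple(g for g in grad_output)


linear = torch.nn.Linear(8, 8)
linear.register_full_backward_pre_hook(cast_grad_output_backward_pre_hook)
out = linear(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
```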

drisspg (Contributor, Author) replied:

Sounds good!

wanchaol (Contributor) left a comment:

Stamp to unblock; thanks for getting this to work!

facebook-github-bot (Contributor) commented:
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot (Contributor) commented:
@drisspg merged this pull request in 6df3e55.

Labels: CLA Signed, Merged
3 participants