This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 20
bring back torch.autograd.Function #316
Open
vkuzo
wants to merge
5
commits into
gh/vkuzo/29/base
Choose a base branch
from
gh/vkuzo/29/head
base: gh/vkuzo/29/base
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Commits on Jul 16, 2024
-
[TBD if for land] bring back torch.autograd.Function
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for c1487ef - Browse repository at this point
Copy the full SHA c1487efView commit details -
Update on "[TBD if for land] bring back torch.autograd.Function"
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. ``` # this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files # and modified to only support dynamic scaling # # Why do we want a torch.autograd.Function here? Vasiliy's opinion is that # as we add more scaling granularities, keeping the scaling code close to Float8Linear # will be really useful for readability and debuggability of numerics. # # For example, a future PR to add rowwise scaling could do # # # forward # x_bf16 = ... # if scaling_granularity == ScalingGranularity.PER_TENSOR: # # we can scale the same way for fwd/bwd # x_maybe_fp8 = to_fp8(...) # else: # assert scaling_granularity == ScalingGranularity.PER_ROW: # # defer scaling to float8_mm # x_maybe_fp8 = x_bf16 # # # repeat for w # # y_bf16 = float8_mm(x_maybe_fp8, w_maybe_fp8) # # Requirements for float8_mm # - composes with DTensor, compile, autograd # - readable/debuggable # # Option 1 (this PR): float8_mm is a torch.autograd.Function # - pros # - cons # Option 2 (current code without this PR): float8_mm is an override of torch.mm # - pros # - cons # ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for 8505776 - Browse repository at this point
Copy the full SHA 8505776View commit details
Commits on Jul 22, 2024
-
Update on "[TBD if for land] bring back torch.autograd.Function"
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. We should discuss whether we want Option 1 (keep overriding torch.mm) or Option 2 (torch.autograd.Function). Vasiliy: I think Option 2 is cleaner/more readable/more debuggable, modeling code is usually written in the module or similar torch.autograd.Function overrides. I would consider scaling tensors to float8 modeling code, and it's unintuitive IMO for this to happen deep inside op overrides. However, Option 1 is less risky technically as we avoid torch.autograd.Function which is less mature in interactions with torch.compile. While the current PR is all green, we are using `allow_in_graph` which is a bit unsafe. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for a820346 - Browse repository at this point
Copy the full SHA a820346View commit details -
Update on "bring back torch.autograd.Function"
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. This PR implements Option 2 as IMO this is more readable/debuggable. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for 47ef29d - Browse repository at this point
Copy the full SHA 47ef29dView commit details -
Update on "bring back torch.autograd.Function"
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. This PR implements Option 2 as IMO this is more readable/debuggable. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
Configuration menu - View commit details
-
Copy full SHA for ec8829e - Browse repository at this point
Copy the full SHA ec8829eView commit details
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.