Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
Update base for Update on "bring back torch.autograd.Function for flo…
Browse files Browse the repository at this point in the history
…at8 matmul"

Summary:

This is a redo of
#316

With upcoming support of scaling granularities other than tensorwise,
we need a good way to control which gemm kernel to call and how to scale
the input tensors in fwd and bwd. A `torch.autograd.Function` override
is the cleanest way to do that, and in 2024 this now works with
`torch.compile`.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068)

[ghstack-poisoned]
  • Loading branch information
vkuzo committed Jul 26, 2024
2 parents 224cfdf + 0aca10a commit bc4438e
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions float8_experimental/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:

# inf-norm is equivalent to max(abs(w))
max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial
amax_tensor = torch.vstack(max_weights) # Partial
amax_tensor = torch.stack(max_weights) # Partial
# clamp is dispatched through DTensor
# it will issue a single all-reduce
amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
if amax_tensor.dtype is torch.float16:
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
scales = torch.split(scale_tensor, 1) # Replicate
for scale, float8_linear in zip(scales, float8_linears):
float8_linear.weight._local_tensor._precomputed_scale = (
scale._local_tensor.squeeze()
)
local_scale_tensor = scale_tensor.to_local()
for i, float8_linear in enumerate(float8_linears):
float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
Expand Down

0 comments on commit bc4438e

Please sign in to comment.