bring back torch.autograd.Function for float8 matmul
Summary:

This is a redo of #316.

With upcoming support for scaling granularities other than tensorwise,
we need a good way to control which gemm kernel to call and how to scale
the input tensors in the forward and backward passes. A `torch.autograd.Function`
override is the cleanest way to do that, and as of 2024 this works with
`torch.compile`.
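
For context, a minimal, hypothetical sketch of the pattern this change relies on (the names `scaled_mm_demo` and `f` are illustrative, not from this PR): an autograd.Function wrapped with `@torch._dynamo.allow_in_graph`, whose forward and backward can each pick their own gemm and scaling behavior, and which still traces under `torch.compile`.

```python
# Hypothetical sketch, not part of this PR; names are illustrative only.
import torch

@torch._dynamo.allow_in_graph
class scaled_mm_demo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b_t):
        # a real implementation would pick a gemm kernel and scale `a` / `b_t`
        # here; this sketch just calls torch.mm
        ctx.save_for_backward(a, b_t)
        return torch.mm(a, b_t)

    @staticmethod
    def backward(ctx, grad_out):
        a, b_t = ctx.saved_tensors
        # backward is free to scale/convert grad_out differently from forward
        grad_a = torch.mm(grad_out, b_t.t())
        grad_b_t = torch.mm(a.t(), grad_out)
        return grad_a, grad_b_t

def f(a, b_t):
    return scaled_mm_demo.apply(a, b_t)

a = torch.randn(8, 16, requires_grad=True)
b_t = torch.randn(16, 4, requires_grad=True)
out = torch.compile(f)(a, b_t)
out.sum().backward()  # should run without graph breaks
```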

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 42dd59511e4ec2a55846c2593955c4ff5f12b254
Pull Request resolved: #336
vkuzo committed Jul 25, 2024
1 parent 013ac0c commit 113dbdf
Showing 1 changed file with 57 additions and 1 deletion.
float8_experimental/float8_linear.py: 57 additions, 1 deletion
@@ -71,6 +71,62 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
scale.copy_(new_scale)


# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
@torch._dynamo.allow_in_graph
class manual_float8_matmul(torch.autograd.Function):
    """
    Like torch.matmul, but with the arguments in float8
    """

    @staticmethod
    def forward(
        ctx,
        input_fp8,
        weight_fp8_t,
    ):
        ctx.save_for_backward(input_fp8, weight_fp8_t)
        # the reshapes are needed in order to make the shapes compatible with
        # torch.mm
        orig_shape = input_fp8.shape
        input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
        res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
        return res_bits

    @staticmethod
    def backward(ctx, grad_output_fp8):
        input_fp8, weight_fp8_t = ctx.saved_tensors

        # the reshapes are needed in order to make the shapes compatible with
        # torch.mm
        grad_output_fp8_orig_shape = grad_output_fp8.shape
        grad_output_fp8_reshaped = grad_output_fp8.reshape(
            -1, grad_output_fp8_orig_shape[-1]
        )

        # calculate grad_input
        grad_input = torch.mm(
            grad_output_fp8_reshaped,
            weight_fp8_t.t(),
        )
        grad_input = grad_input.reshape(
            *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1]
        )

        input_fp8_orig_shape = input_fp8.shape
        input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1])

        # calculate grad_weight
        # Note: the variant below is slightly faster on LLaMa 3 8B pretraining
        # compared to calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped`
        grad_weight = torch.mm(
            grad_output_fp8_reshaped.t(),
            input_fp8_reshaped,
        )

        return grad_input, grad_weight.t()


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
"""
@@ -393,7 +449,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
        input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
        weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

-       output = torch.matmul(input_fp8, weight_fp8.t())
+       output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())

        # Cast grad_output to float8_e5m2 during backward
        output = self.cast_output_to_float8_in_bw(output)
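
As an aside, a hedged sanity check (not from this PR) of the new multiply path: on plain float32 tensors it should match `torch.matmul` numerically in both output and gradients, since the float8 conversion happens outside this Function. The sketch below assumes the module is importable as `float8_experimental.float8_linear` and only exercises the autograd wiring, not the float8 kernels.

```python
import torch

from float8_experimental.float8_linear import manual_float8_matmul

# two identical sets of leaf tensors, one per path
x1 = torch.randn(4, 3, 16, requires_grad=True)
w1 = torch.randn(8, 16, requires_grad=True)
x2 = x1.detach().clone().requires_grad_()
w2 = w1.detach().clone().requires_grad_()

ref = torch.matmul(x1, w1.t())
new = manual_float8_matmul.apply(x2, w2.t())

ref.sum().backward()
new.sum().backward()

# outputs and gradients should agree in the plain-float case
torch.testing.assert_close(new, ref)
torch.testing.assert_close(x2.grad, x1.grad)
torch.testing.assert_close(w2.grad, w1.grad)
```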
