[TBD if for land] bring back torch.autograd.Function
Summary:

This approach is more readable as we add additional scaling options.

For now, the goal is to see how many things break, as of 2024-07, with
torch.autograd.Function + subclasses + compile.
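
As a rough, self-contained sketch (not code from this commit; the Function name, the shapes, and the use of plain tensors instead of Float8Tensor subclasses are made up for illustration), the pattern being stress-tested is a custom torch.autograd.Function marked with allow_in_graph and called under torch.compile:

# Minimal sketch of the pattern under test (hypothetical names, plain tensors).
import torch

@torch._dynamo.allow_in_graph
class _toy_matmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return torch.mm(a, b)

    @staticmethod
    def backward(ctx, go):
        a, b = ctx.saved_tensors
        return torch.mm(go, b.t()), torch.mm(a.t(), go)

@torch.compile
def fn(a, b):
    return _toy_matmul.apply(a, b)

a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(8, 16, requires_grad=True)
fn(a, b).sum().backward()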

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: a7abb00cce87d18273d3bb18996eebb2bb0c4c99
Pull Request resolved: #316
vkuzo committed Jul 22, 2024
1 parent c58fb5d commit 41a6395
Showing 2 changed files with 56 additions and 2 deletions.
54 changes: 53 additions & 1 deletion float8_experimental/float8_linear.py
@@ -71,6 +71,58 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
        scale.copy_(new_scale)


# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
@torch._dynamo.allow_in_graph
class manual_float8_matmul(torch.autograd.Function):
    """
    Like torch.matmul, but with the arguments in float8.
    """

    @staticmethod
    def forward(
        ctx,
        x_fp8,
        w_fp8_t,
    ):
        ctx.save_for_backward(x_fp8, w_fp8_t)
        # the reshapes are needed in order to make the shapes compatible with
        # torch.mm
        orig_shape = x_fp8.shape
        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
        res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
        return res_bits

    @staticmethod
    def backward(ctx, go_fp8):
        x_fp8, w_fp8_t = ctx.saved_tensors

        # the reshapes are needed in order to make the shapes compatible with
        # torch.mm
        go_fp8_orig_shape = go_fp8.shape
        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])

        # calculate dL/dX
        dL_dX = torch.mm(
            go_fp8_reshaped,
            w_fp8_t.t(),
        )
        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])

        x_fp8_orig_shape = x_fp8.shape
        x_fp8_reshaped = x_fp8.reshape(-1, x_fp8_orig_shape[-1])

        # calculate dL/dW
        # Note: the variant below is slightly faster on LLaMa 3 8B pretraining
        # compared to calculating `dL_dW_t = x_fp8_t @ go_fp8_reshaped`
        dL_dW = torch.mm(
            go_fp8_reshaped.t(),
            x_fp8_reshaped,
        )

        return dL_dX, dL_dW.t()


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
"""
@@ -410,7 +462,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
        x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

-        y = torch.matmul(x_fp8, w_fp8.t())
+        y = manual_float8_matmul.apply(x_fp8, w_fp8.t())

        # Cast gradY to float8_e5m2 during backward
        y = self.cast_y_to_float8_in_bw(y)
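
As a quick illustrative sanity check of the backward above (not part of this commit; it assumes manual_float8_matmul is importable from float8_experimental.float8_linear and substitutes plain fp32 tensors for Float8Tensor, which works because the Function only uses reshape, mm, and t), the gradients should match torch.matmul:

import torch
from float8_experimental.float8_linear import manual_float8_matmul

# arbitrary shapes; the Function flattens the leading dims before torch.mm
x = torch.randn(2, 3, 8, requires_grad=True)
w_t = torch.randn(8, 16, requires_grad=True)  # weight passed already transposed

y_ref = torch.matmul(x, w_t)
y = manual_float8_matmul.apply(x, w_t)
torch.testing.assert_close(y, y_ref)

go = torch.randn_like(y)
gx_ref, gw_t_ref = torch.autograd.grad(y_ref, (x, w_t), go)
gx, gw_t = torch.autograd.grad(y, (x, w_t), go)
# dL/dX = dL/dY @ w_t.t(); dL/dw_t = x.t() @ dL/dY,
# computed in the backward above as (dL/dY.t() @ x).t()
torch.testing.assert_close(gx, gx_ref)
torch.testing.assert_close(gw_t, gw_t_ref)
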
4 changes: 3 additions & 1 deletion float8_experimental/float8_tensor.py
@@ -101,7 +101,9 @@ def choose_scaled_mm_config(
            a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
        ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
        return a_linear_mm_config.dL_dX
-    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
+    elif (a_role is GemmInputRole.X and b_role is GemmInputRole.DL_DY) or (
+        a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X
+    ):
        assert (
            a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
        ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
