[TBD if for land] bring back torch.autograd.Function
Summary:

This approach is more readable as we add additional scaling options.

For now, this is also a way to see how many things break in 2024-07 when combining
torch.autograd.Function + subclasses + compile.
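
As an illustration of the composition being stress-tested, here is a minimal, hypothetical sketch (not code from this PR; the names below are made up) of an autograd.Function that is kept as a single node in the Dynamo graph and run under torch.compile. The float8 path additionally feeds Float8Tensor subclass inputs through the same pattern.

    # Hypothetical sketch: autograd.Function + allow_in_graph + torch.compile
    import torch

    @torch._dynamo.allow_in_graph
    class _MyLinearNoBias(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, w):
            ctx.save_for_backward(x, w)
            # y = x @ w.t(), the same contract as F.linear without bias
            return torch.mm(x, w.t())

        @staticmethod
        def backward(ctx, grad_out):
            x, w = ctx.saved_tensors
            # dL/dx = dL/dy @ w, dL/dw = (dL/dy).t() @ x
            return torch.mm(grad_out, w), torch.mm(grad_out.t(), x)

    def _fn(x, w):
        return _MyLinearNoBias.apply(x, w)

    compiled_fn = torch.compile(_fn)
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    compiled_fn(x, w).sum().backward()  # exercises the Function under compile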

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 09c4625b2a859ce6468bac328d5f0ff61bb86251
Pull Request resolved: #316
vkuzo committed Jul 16, 2024
1 parent de93990 commit 352f29e
Showing 1 changed file with 101 additions and 1 deletion.
float8_experimental/float8_linear.py

@@ -68,6 +68,101 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
        )
        scale.copy_(new_scale)

# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
# and modified to only support dynamic scaling
@torch._dynamo.allow_in_graph
class float8_linear(torch.autograd.Function):
    """
    Like F.linear, but with X and W in float8
    """

    @staticmethod
    def forward(
        ctx,
        x_fp8,
        w_fp8,
        emulate: bool,
        # TODO(this PR): split config into fwd/bwd
        mm_config: ScaledMMConfig,
    ):
        ctx.save_for_backward(x_fp8, w_fp8)
        ctx.emulate = emulate
        ctx.mm_config = mm_config
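        # Note: x_fp8 and w_fp8 are Float8Tensor subclass instances here, so the
        # reshape/transpose/mm calls below operate on the subclass directly instead
        # of rebuilding Float8Tensor from ._data/._scale as the PR #128 version did
        # (kept in the commented-out lines below). emulate and mm_config are stashed
        # on ctx so the backward can access them.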
        # orig_shape = x_fp8._data.shape
        orig_shape = x_fp8.shape
        # x_fp8_reshaped = Float8Tensor(
        #     x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype, mm_config
        # )
        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])

        # w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, mm_config)
        w_fp8_t = w_fp8.t()

        res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
        return res_bits

    @staticmethod
    def backward(ctx, go_fp8):
        x_fp8, w_fp8 = ctx.saved_tensors
        emulate = ctx.emulate
        mm_config = ctx.mm_config
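        # With y = x @ w.t() in the forward, the gradients computed below are
        #   dL/dX = dL/dY @ W
        #   dL/dW = (X^T @ dL/dY)^T
        # The .t().contiguous().t() round trips only change memory layout (strides)
        # for the mm kernels; they do not change any values.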

        go_fp8_orig_shape = go_fp8.shape
        # go_fp8_reshaped = Float8Tensor(
        #     go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
        #     go_fp8._scale,
        #     go_fp8._orig_dtype,
        #     mm_config,
        # )
        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])

        # w_fp8_t_c_t = Float8Tensor(
        #     w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype, mm_config
        # )
        w_fp8_t_c_t = w_fp8.t().contiguous().t()

        #
        # calculate dL/dX
        #
        dL_dX = torch.mm(go_fp8_reshaped, w_fp8_t_c_t)
        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])

        # x_fp8_orig_shape = x_fp8._data.shape
        x_fp8_orig_shape = x_fp8.shape
        # x_fp8_reshaped_t_c = Float8Tensor(
        #     x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
        #     x_fp8._scale,
        #     x_fp8._orig_dtype,
        #     mm_config,
        # )
        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()

        # go_fp8_reshaped_t_c_t = Float8Tensor(
        #     go_fp8_reshaped._data.t().contiguous().t(),
        #     go_fp8_reshaped._scale,
        #     go_fp8_reshaped._orig_dtype,
        #     mm_config,
        # )
        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()

        #
        # calculate dL/dW
        #
        dL_dW = torch.mm(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t)
        dL_dW = dL_dW.t()

        empty_grads = None, None, None, None, None, None, None, None, None
        return dL_dX, dL_dW, *empty_grads


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -394,7 +489,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

-        y = torch.matmul(x_fp8, w_fp8.t())
+        if not self.has_any_delayed_scaling:
+            emulate = False
+            mm_config = self.forward_config
+            y = float8_linear.apply(x_fp8, w_fp8, emulate, mm_config)
+        else:
+            y = torch.matmul(x_fp8, w_fp8.t())

         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
