From 352f29ee77c2a8e977e53f3d07385393f8bec326 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 16 Jul 2024 09:18:39 -0700
Subject: [PATCH] [TBD if for land] bring back torch.autograd.Function

Summary:

This approach is more readable as we add additional scaling options.

For now, this checks how many things break, as of 2024-07, when combining
torch.autograd.Function + tensor subclasses + torch.compile.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 09c4625b2a859ce6468bac328d5f0ff61bb86251
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/316
---
 float8_experimental/float8_linear.py | 102 ++++++++++++++++++++++++++-
 1 file changed, 101 insertions(+), 1 deletion(-)

diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 7850738..787a54c 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -68,6 +68,101 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
     )
     scale.copy_(new_scale)
+
+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+# and modified to only support dynamic scaling
+@torch._dynamo.allow_in_graph
+class float8_linear(torch.autograd.Function):
+    """
+    Like F.linear, but with X and W in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8,
+        emulate: bool,
+        # TODO(this PR): split config into fwd/bwd
+        mm_config: ScaledMMConfig,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8)
+        ctx.emulate = emulate
+        ctx.mm_config = mm_config
+        # orig_shape = x_fp8._data.shape
+        orig_shape = x_fp8.shape
+        # x_fp8_reshaped = Float8Tensor(
+        #     x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype, mm_config
+        # )
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+
+        # w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, mm_config)
+        w_fp8_t = w_fp8.t()
+
+        res_bits = torch.mm(
+            x_fp8_reshaped, w_fp8_t
+        )
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8 = ctx.saved_tensors
+        emulate = ctx.emulate
+        mm_config = ctx.mm_config
+
+        go_fp8_orig_shape = go_fp8.shape
+        # go_fp8_reshaped = Float8Tensor(
+        #     go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
+        #     go_fp8._scale,
+        #     go_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        # w_fp8_t_c_t = Float8Tensor(
+        #     w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype, mm_config
+        # )
+        w_fp8_t_c_t = w_fp8.t().contiguous().t()
+
+        #
+        # calculate dL/dX
+        #
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t_c_t,
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        # x_fp8_orig_shape = x_fp8._data.shape
+        x_fp8_orig_shape = x_fp8.shape
+        # x_fp8_reshaped_t_c = Float8Tensor(
+        #     x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
+        #     x_fp8._scale,
+        #     x_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()
+
+        # go_fp8_reshaped_t_c_t = Float8Tensor(
+        #     go_fp8_reshaped._data.t().contiguous().t(),
+        #     go_fp8_reshaped._scale,
+        #     go_fp8_reshaped._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
+
+        #
+        # calculate dL/dW
+        #
+        dL_dW = torch.mm(
+            x_fp8_reshaped_t_c,
+            go_fp8_reshaped_t_c_t,
+        )
+        dL_dW = dL_dW.t()
+
+        empty_grads = None, None  # one grad per remaining forward input (emulate, mm_config)
+        return dL_dX, dL_dW, *empty_grads
 
 
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -394,7 +489,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        if not self.has_any_delayed_scaling:
+            emulate = False
+            mm_config = self.forward_config
+            y = float8_linear.apply(x_fp8, w_fp8, emulate, mm_config)
+        else:
+            y = torch.matmul(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
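
A minimal, self-contained sketch of the pattern this patch stress-tests: a
custom torch.autograd.Function marked with @torch._dynamo.allow_in_graph,
called from an nn.Module that is run under torch.compile. The names
toy_linear and ToyLinear are hypothetical stand-ins, and plain fp32 tensors
are used in place of Float8Tensor, so this illustrates only the
autograd + compile structure, not the float8 casting or ScaledMMConfig
handling from the patch above.

import torch


@torch._dynamo.allow_in_graph
class toy_linear(torch.autograd.Function):
    # hypothetical stand-in for float8_linear: same shape handling as the
    # patch's forward/backward, but without float8 casts or mm configs
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        orig_shape = x.shape
        x_reshaped = x.reshape(-1, orig_shape[-1])
        res = torch.mm(x_reshaped, w.t())
        return res.reshape(*orig_shape[:-1], res.shape[-1])

    @staticmethod
    def backward(ctx, go):
        x, w = ctx.saved_tensors
        go_reshaped = go.reshape(-1, go.shape[-1])
        # dL/dX = dL/dY @ W, reshaped back to the input's leading dims
        dL_dX = torch.mm(go_reshaped, w).reshape(*go.shape[:-1], w.shape[-1])
        # dL/dW = dL/dY^T @ X
        dL_dW = torch.mm(go_reshaped.t(), x.reshape(-1, x.shape[-1]))
        # exactly one gradient per forward input (x, w)
        return dL_dX, dL_dW


class ToyLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        return toy_linear.apply(x, self.weight)


m = torch.compile(ToyLinear(8, 16))
y = m(torch.randn(4, 8, requires_grad=True))
y.sum().backward()  # exercises the custom backward under compile

Under compile, allow_in_graph makes dynamo place the call to the
autograd.Function into its graph as-is instead of tracing into the Python
body, which is the compile-side half of the interaction the commit message
sets out to exercise.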