diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 7850738..16c1257 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -69,6 +69,66 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         scale.copy_(new_scale)


+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+# and modified to only support dynamic scaling
+@torch._dynamo.allow_in_graph
+class float8_mm(torch.autograd.Function):
+    """
+    Like torch.mm, but with X and W in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8)
+        orig_shape = x_fp8.shape
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+
+        w_fp8_t = w_fp8.t()
+
+        res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8 = ctx.saved_tensors
+
+        go_fp8_orig_shape = go_fp8.shape
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        w_fp8_t_c_t = w_fp8.t().contiguous().t()
+
+        #
+        # calculate dL/dX
+        #
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t_c_t,
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        x_fp8_orig_shape = x_fp8.shape
+        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()
+
+        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
+
+        #
+        # calculate dL/dW
+        #
+        dL_dW = torch.mm(
+            x_fp8_reshaped_t_c,
+            go_fp8_reshaped_t_c_t,
+        )
+        dL_dW = dL_dW.t()
+
+        empty_grads = (None,)
+        return dL_dX, dL_dW, *empty_grads
+
+
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     """
@@ -394,7 +454,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

-        y = torch.matmul(x_fp8, w_fp8.t())
+        y = float8_mm.apply(x_fp8, w_fp8)

         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
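
The backward above is the standard linear-layer gradient math, dL/dX = dL/dY @ W and dL/dW = (X^T @ dL/dY)^T; the .t().contiguous().t() calls change only memory layout (presumably to hand the fp8 gemm its operands in the layout it expects), not values. Below is a minimal sketch of the same reshape/transpose bookkeeping in high precision, checked against torch.nn.functional.linear; plain_mm, the shapes, and the double dtype are illustrative assumptions, not part of this PR.

import torch

# A high-precision stand-in for float8_mm above: same reshape/transpose
# bookkeeping, but with the float8 casting stripped out so the gradient
# math can be checked directly against torch.nn.functional.linear.
# (plain_mm and the shapes below are illustrative only, not part of the PR.)
class plain_mm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        orig_shape = x.shape
        x_reshaped = x.reshape(-1, orig_shape[-1])
        res = torch.mm(x_reshaped, w.t())
        return res.reshape(*orig_shape[:-1], res.shape[-1])

    @staticmethod
    def backward(ctx, go):
        x, w = ctx.saved_tensors
        go_reshaped = go.reshape(-1, go.shape[-1])
        # dL/dX = dL/dY @ W, restored to the input's leading dims
        dL_dX = torch.mm(go_reshaped, w).reshape(*go.shape[:-1], w.shape[-1])
        # dL/dW = (X^T @ dL/dY)^T, matching the weight's (out, in) layout
        x_reshaped_t = x.reshape(-1, x.shape[-1]).t()
        dL_dW = torch.mm(x_reshaped_t, go_reshaped).t()
        return dL_dX, dL_dW

x = torch.randn(4, 8, 16, dtype=torch.double, requires_grad=True)
w = torch.randn(32, 16, dtype=torch.double, requires_grad=True)
g = torch.randn(4, 8, 32, dtype=torch.double)

y_ref = torch.nn.functional.linear(x, w)
y = plain_mm.apply(x, w)
torch.testing.assert_close(y, y_ref)

dx_ref, dw_ref = torch.autograd.grad(y_ref, (x, w), g)
dx, dw = torch.autograd.grad(y, (x, w), g)
torch.testing.assert_close(dx, dx_ref)
torch.testing.assert_close(dw, dw_ref)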