diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 37cf0d5..4823cee 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -71,6 +71,58 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         scale.copy_(new_scale)
 
 
+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+@torch._dynamo.allow_in_graph
+class manual_float8_matmul(torch.autograd.Function):
+    """
+    Like torch.matmul, but with the arguments in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8_t,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8_t)
+        # the reshapes are needed in order to make the shapes compatible with
+        # torch.mm
+        orig_shape = x_fp8.shape
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+        res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8_t = ctx.saved_tensors
+
+        # the reshapes are needed in order to make the shapes compatible with
+        # torch.mm
+        go_fp8_orig_shape = go_fp8.shape
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        # calculate dL/dX
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t.t(),
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        x_fp8_orig_shape = x_fp8.shape
+        x_fp8_reshaped = x_fp8.reshape(-1, x_fp8_orig_shape[-1])
+
+        # calculate dL/dW
+        # Note: the variant below is slightly faster on LLaMa 3 8B pretraining
+        # compared to calculating `dL_dW_t = x_fp8_t @ go_fp8_reshaped`
+        dL_dW = torch.mm(
+            go_fp8_reshaped.t(),
+            x_fp8_reshaped,
+        )
+
+        return dL_dX, dL_dW.t()
+
+
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     """
@@ -410,7 +462,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        y = manual_float8_matmul.apply(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 475a17a..491dc24 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -101,7 +101,9 @@ def choose_scaled_mm_config(
             a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
         ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
         return a_linear_mm_config.dL_dX
-    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
+    elif (a_role is GemmInputRole.X and b_role is GemmInputRole.DL_DY) or (
+        a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X
+    ):
         assert (
             a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
         ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
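
For reference, here is a minimal sanity check (not part of the patch) that the manual autograd function reproduces `torch.matmul`'s forward output and gradients. It assumes the patched `float8_experimental.float8_linear` module is importable, and it uses plain `float32` tensors in place of the float8 wrappers, since the reshape/`torch.mm` logic being tested is the same either way: `dL/dX = dL/dY @ W` and `dL/dW = (dL/dY)^T @ X`.

```python
import torch

# assumes the manual_float8_matmul class added in the diff above is in scope
from float8_experimental.float8_linear import manual_float8_matmul

# shapes mirror the linear call site: input (batch, seq, in), weight (out, in)
x = torch.randn(4, 8, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=True)

# reference path: plain matmul
y_ref = torch.matmul(x, w.t())
y_ref.sum().backward()
x_grad_ref, w_grad_ref = x.grad.clone(), w.grad.clone()
x.grad, w.grad = None, None

# manual path: the same gradient formulas, spelled out explicitly in backward
y = manual_float8_matmul.apply(x, w.t())
y.sum().backward()

torch.testing.assert_close(y, y_ref)
torch.testing.assert_close(x.grad, x_grad_ref)
torch.testing.assert_close(w.grad, w_grad_ref)
```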