From 6ecba553d012755d694db4f82b29193b3a4eb633 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Jul 2024 13:35:54 -0700 Subject: [PATCH] [wip] add axiswise granularity to Float8Tensor Summary: This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`. Test Plan: TODO Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 23f0f0c79d47ddeb8c929dc13583d04436b34b35 Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/351 --- float8_experimental/config.py | 12 +++++ float8_experimental/float8_ops.py | 40 ++++++++++++++++- float8_experimental/float8_python_api.py | 8 ++++ float8_experimental/float8_scaling_utils.py | 14 +++++- float8_experimental/float8_tensor.py | 13 +++--- float8_experimental/float8_utils.py | 30 ++++++++++--- test/test_base.py | 50 ++++++++++++++++++++- 7 files changed, 151 insertions(+), 16 deletions(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 5d1bf9f..217fca1 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -21,6 +21,18 @@ def short_str(self): return "dyn" +class ScalingGranularity(enum.Enum): + """ + Defines the granularity of scaling strategies for casting to float8 + """ + + # A single scaling factor for the entire tensor + TENSORWISE = "tensorwise" + # Scaling factors computed along one axis of the tensor, reducing it to + # size 1. 
+ AXISWISE = "axiswise" + + @dataclass(frozen=True) class CastConfig: """ diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 2a11726..588d48a 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -19,6 +19,15 @@ FLOAT8_OPS_TABLE: Dict[Any, Any] = {} +def _assert_tensorwise_scale(aten_op, scale): + assert ( + # TODO(future PR): figure out why tensorwise scaling can have + # both rank 0 and rank 1 + len(scale.shape) + in (0, 1) + ), f"{aten_op} with axiswise scaling is not supported yet" + + def implements(aten_ops): """Register aten ops to the float8 op table""" @@ -34,16 +43,15 @@ def decorator(func): [ aten.view.default, aten._unsafe_view.default, - aten.t.default, aten.as_strided.default, aten.clone.default, aten.detach.default, aten.slice.Tensor, - aten.transpose.int, aten.fill_.Scalar, ] ) def float8_desugar_op(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( new_data, @@ -54,8 +62,27 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.t.default, + aten.transpose.int, + ] +) +def float8_desugar_data_and_scale(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + return Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + @implements([aten.split.Tensor]) def float8_split(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs) def make_float8(data): @@ -101,6 +128,7 @@ def float8_cat(aten_op, args, kwargs=None): assert ( chunk._gemm_input_role is gemm_input_role ), "Expecting all chunks to have the same gemm_input_role as a result of a split" + _assert_tensorwise_scale(aten_op, chunk._scale) 
chunk_data.append(chunk._data.view(torch.uint8)) new_data = aten_op(chunk_data, *args[1:], **kwargs) @@ -117,6 +145,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None): "addmm" -> out "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" """ + _assert_tensorwise_scale(aten_op, args[0]._scale) def unwrap(x): if isinstance(x, Float8Tensor): @@ -229,6 +258,7 @@ def float8_addmm(aten_op, args, kwargs=None): @implements([aten.is_same_size.default]) def float8_is_same_size(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) return args[0].shape == args[1].shape @@ -238,6 +268,7 @@ def autocast_to_copy(aten_op, args, kwargs=None): when the input is a Float8Tensor, presenting as a fp32 tensor. """ + _assert_tensorwise_scale(aten_op, args[0]._scale) assert isinstance(args[0], Float8Tensor) assert ( len(kwargs) == 1 and "dtype" in kwargs @@ -265,6 +296,7 @@ def allgather_fp8(aten_op, args, kwargs=None): """ override funcol with FP8 handling """ + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance( fp8_input, Float8Tensor @@ -284,6 +316,7 @@ def allgather_fp8(aten_op, args, kwargs=None): @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default]) def wait_tensor_fp8(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance(fp8_input, Float8Tensor) @@ -304,6 +337,7 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_values = args[2] assert isinstance(fp8_self, Float8Tensor) assert isinstance(fp8_values, Float8Tensor) + _assert_tensorwise_scale(aten_op, args[0]._scale) assert fp8_self._scale == fp8_values._scale assert fp8_self.dtype == fp8_values.dtype assert fp8_self._orig_dtype == fp8_values._orig_dtype @@ -334,8 +368,10 @@ def copy_fp8(aten_op, args, kwargs=None): if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): src_hp = src.to_original_precision() 
_assert_tensorwise_scale(aten_op, src._scale) return aten_op(self, src_hp, *args[2:], **kwargs) elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + _assert_tensorwise_scale(aten_op, src._scale) assert ( self._orig_dtype == src._orig_dtype ), "Expecting both Float8Tensors to be of the same dtype" diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py index d8aa081..001eff4 100644 --- a/float8_experimental/float8_python_api.py +++ b/float8_experimental/float8_python_api.py @@ -38,6 +38,14 @@ def addmm_float8_unwrapped( """ a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() + + # TODO: should we change torch._scaled_mm? + # torch._scaled_mm expects rowwise scaled scales to be of rank 1, not rank + # 2. Translate to this format. + # TODO: audit if we need to make this more generic for various shapes. + a_inverse_scale = a_inverse_scale.squeeze() + b_inverse_scale = b_inverse_scale.squeeze() + if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 output = torch._scaled_mm( diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py index ce6422f..06c93c1 100644 --- a/float8_experimental/float8_scaling_utils.py +++ b/float8_experimental/float8_scaling_utils.py @@ -12,6 +12,8 @@ import torch +from float8_experimental.config import ScalingGranularity + from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -49,10 +53,18 @@ def hp_tensor_to_float8_dynamic( reduce_amax: whether to reduce the max(abs(hp_tensor)) value across 
distributed ranks gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, the dim to keep (one scale per index along it); all other dims are reduced """ if tensor_already_casted_to_fp8(hp_tensor): return hp_tensor - scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) + scale = tensor_to_scale( + hp_tensor, + float8_dtype, + reduce_amax, + scaling_granularity, + axiswise_dim, + ) return hp_tensor_and_scale_to_float8( hp_tensor, scale, diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 641f972..22c2a32 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -250,7 +250,12 @@ class Float8Tensor(torch.Tensor): * `_data`: the underlying e4m3 or e5m2 data * `_scale`: the scale used to scale the original fp32 tensor. We multiply by scale to go from fp32 range to fp8 range, and divide by scale to go - from fp8 range to fp32 range. + from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible + with `_data`. For example: + - if scaling is tensorwise, `_scale` is a scalar tensor + - if scaling is axiswise and _data.shape is [3, 5], `_scale` could have + shape [1, 5] or [3, 1]. The dim of the non-one entry defines the scaling + axis. * `_orig_dtype`: the original dtype of the tensor used to create this tensor. 
* `_emulate`: if true using fp32 emulation for the matmuls, helpful @@ -279,12 +284,6 @@ def __new__( linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): - assert ( - scale.numel() == 1 - ), "Scale should contain a single value, but got: {} elements".format( - scale.numel() - ) - self = torch.Tensor._make_wrapper_subclass( cls, data.size(), diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 2be568e..26fde8a 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Tuple, Union +from typing import Iterable, Literal, Optional, Tuple, Union import float8_experimental.config as config import torch import torch.distributed as dist +from float8_experimental.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -100,8 +101,23 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: - amax = torch.max(torch.abs(x)) +def tensor_to_amax( + x: torch.Tensor, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, +) -> torch.Tensor: + if scaling_granularity is ScalingGranularity.TENSORWISE: + amax = torch.max(torch.abs(x)) + else: + assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" + assert axiswise_dim is not None, "unsupported" + + # convert from axiswise_dim (dim to keep) to + # dim as the input to the `torch.amax` function (tuple of dims to reduce) + dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim) + + amax = torch.amax(torch.abs(x), dim=dim_to_reduce, 
keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -114,9 +130,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/test/test_base.py b/test/test_base.py index 4e0c685..38fed52 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -16,7 +16,12 @@ import torch import torch.nn as nn -from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType +from float8_experimental.config import ( + CastConfig, + Float8LinearConfig, + ScalingGranularity, + ScalingType, +) from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( convert_to_float8_training, @@ -24,6 +29,7 @@ sync_float8_amax_and_scale_history, ) from float8_experimental.float8_python_api import addmm_float8_unwrapped +from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -143,6 +149,48 @@ def test_weights_only_load(self): buffer.seek(0) _ = torch.load(buffer, weights_only=True) + def test_axiswise_dynamic_cast(self): + a = torch.randn(16, 32, dtype=torch.bfloat16) + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + # print(a_fp8) + # print(a_fp8.to_original_precision()) + # 
print(a_fp8.t()) + b = a_fp8.t() + # TODO check numerical accuracy + + def test_axiswise_gemm(self): + a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda") + b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + + linear_mm_config = LinearMMConfig() + + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + b_fp8 = hp_tensor_to_float8_dynamic( + b, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + c = torch.mm(a_fp8, b_fp8.t()) + print(c) + # TODO check numerical accuracy + class TestFloat8Linear: def _test_linear_impl(