From 701647ba0bbba3da272af944d09f8adbf5c62e00 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 25 Jul 2024 14:34:27 -0700 Subject: [PATCH 1/6] Reduced CPU overhead in `precompute_float8_dynamic_scale_for_fsdp` (#331) Summary: Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/331 **Description** For Llama3-8B on 8xH100 profiling with `with_stack=True` (which does add overhead), the `precompute_float8_dynamic_scale_for_fsdp` CPU time decreases from 24 ms to 15 ms. Before: Screenshot 2024-07-25 at 10 16 38 AM After: Screenshot 2024-07-25 at 10 17 00 AM **Test Plan** ``` (pytorch-3.10) [andgu@devgpu011.cco1 /data/users/andgu/float8_experimental (precompute_float8)]$ pytest test/test_fsdp2/test_fsdp2.py ========================================================= test session starts ========================================================= platform linux -- Python 3.10.13, pytest-7.3.2, pluggy-1.3.0 rootdir: /data/users/andgu/float8_experimental plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, shard-0.1.2, rerunfailures-13.0, flakefinder-1.1.0, cpp-2.3.0 collected 8 items Running 8 items in this shard test/test_fsdp2/test_fsdp2.py ........ [100%] ========================================================== warnings summary =========================================================== test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_multi_module_parity test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_single_module_parity /data/users/andgu/float8_experimental/float8_experimental/float8_linear_utils.py:272: FutureWarning: The combination of ranks + tag as process group identifier has been deprecated. Please switch to using ProcessGroup, DeviceMesh, or group name instead. all_reduced_amax_tensor = all_reduce( -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ============================================== 8 passed, 2 warnings in 121.90s (0:02:01) ============================================== ``` imported-using-ghimport Test Plan: Imported from OSS Reviewed By: weifengpy Differential Revision: D60236258 Pulled By: awgu fbshipit-source-id: 7b1e48d431dac25d534a77d64d1e5571ad3ad807 --- float8_experimental/fsdp_utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index d9fd200..eef9ec1 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -57,18 +57,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: # inf-norm is equivalent to max(abs(w)) max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial - amax_tensor = torch.vstack(max_weights) # Partial + amax_tensor = torch.stack(max_weights) # Partial # clamp is dispatched through DTensor # it will issue a single all-reduce amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate if amax_tensor.dtype is torch.float16: scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) - scales = torch.split(scale_tensor, 1) # Replicate - for scale, float8_linear in zip(scales, float8_linears): - float8_linear.weight._local_tensor._precomputed_scale = ( - scale._local_tensor.squeeze() - ) + local_scale_tensor = scale_tensor.to_local() + for i, float8_linear in enumerate(float8_linears): + float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i] # FSDP pads its local tensor on dim-0. 
The subclass should be preserved such From eff4ba60570b5e69899860b14688e71e339783dd Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 25 Jul 2024 14:48:12 -0700 Subject: [PATCH 2/6] rename `config.enable_fsdp_fp8_all_gather` to use `float8` (#332) Summary: Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/332 old: `enable_fsdp_fp8_all_gather` new: `enable_fsdp_float8_all_gather` this is to match the `float8` naming elsewhere Reviewed By: weifengpy Differential Revision: D60252072 fbshipit-source-id: 5e240f0a97b647aa4f43a63dab3f03f68fd3b405 --- float8_experimental/config.py | 7 +++-- float8_experimental/float8_linear.py | 2 +- test/test_fsdp2/test_fsdp2.py | 40 ++++++++++++++-------------- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index e190b2f..2e9eacf 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -58,10 +58,9 @@ class Float8LinearConfig: # option is useful for safety, but not strictly necessary. enable_pre_and_post_forward: bool = True - # If True, then uses a tensor subclass for the fp8 linear module's weight that - # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2. - # Only dynamic scaling is supported for now. - enable_fsdp_fp8_all_gather: bool = False + # If True, then uses a tensor subclass for the float8 linear module's weight that + # implements pre/post-all-gather methods to do float8 all-gather with FSDP2. + enable_fsdp_float8_all_gather: bool = False # If True, then prior to performing the fp8 scaled mamtmul we will pad the # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 38e10e5..1d8519e 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -467,7 +467,7 @@ def from_float( # 1. weight needs to be on the correct device to create the buffers # 2. buffers need to be already created for the delayed scaling version # of the weight wrapper to be initialized - if config.enable_fsdp_fp8_all_gather: + if config.enable_fsdp_float8_all_gather: if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC: new_mod.weight = torch.nn.Parameter( WeightWithDynamicFloat8CastTensor( diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py index f40ad25..6d5719a 100644 --- a/test/test_fsdp2/test_fsdp2.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -83,7 +83,7 @@ def world_size(self) -> int: def test_transformer_parity(self): self.run_subtests( { - "enable_fsdp_fp8_all_gather": [False, True], + "enable_fsdp_float8_all_gather": [False, True], "precompute": [False, True], "scaling_type_weight": [ TensorScalingType.DYNAMIC, @@ -96,12 +96,12 @@ def test_transformer_parity(self): def _test_transformer_parity( self, - enable_fsdp_fp8_all_gather: bool, + enable_fsdp_float8_all_gather: bool, precompute: bool, scaling_type_weight: TensorScalingType, compile_transformer_block: bool, ): - if not enable_fsdp_fp8_all_gather and precompute: + if not enable_fsdp_float8_all_gather and precompute: return elif scaling_type_weight is TensorScalingType.DELAYED and precompute: return @@ -110,7 +110,7 @@ def _test_transformer_parity( # embedding weight and output linear weight are tied but only the # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to # fp8 for that tied weight, incorrectly using fp8 for the embedding. 
- weight_tying = not enable_fsdp_fp8_all_gather + weight_tying = not enable_fsdp_float8_all_gather module = self.init_transformer(weight_tying=weight_tying).cuda() ref_module = copy.deepcopy(module) float8_linear_config1 = Float8LinearConfig( @@ -125,7 +125,7 @@ def _test_transformer_parity( transformer_block = torch.compile(transformer_block, dynamic=False) ref_module.layers.register_module(layer_id, transformer_block) float8_linear_config2 = Float8LinearConfig( - enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather, + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), ) convert_to_float8_training( @@ -158,10 +158,10 @@ def _test_transformer_parity( @skip_if_lt_x_gpu(2) def test_transformer_memory(self): """Tests peak active memory in the forward and backward passes.""" - for enable_fsdp_fp8_all_gather in [False, True]: - self._test_transformer_memory(enable_fsdp_fp8_all_gather) + for enable_fsdp_float8_all_gather in [False, True]: + self._test_transformer_memory(enable_fsdp_float8_all_gather) - def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool): + def _test_transformer_memory(self, enable_fsdp_float8_all_gather: bool): torch.manual_seed(42) # Pre-run a linear forward (gemm and bias) and backward (gemm) to # allocate the cuBLAS workspaces before measuring the memory usage @@ -184,7 +184,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool): # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility # requirement to use a smaller activation size float8_linear_config = Float8LinearConfig( - enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather, + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, emulate=True, ) convert_to_float8_training(model, config=float8_linear_config) @@ -231,7 +231,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool): # number is kept much smaller than the actual memory usage, which is on # the order of 100-200+ MB) buffer_mb = 16 - if enable_fsdp_fp8_all_gather: + if enable_fsdp_float8_all_gather: # Non-block parameters (fp32), 3x block non-linear-weight # parameters (fp32) and block linear-weight parameters (fp8) # (current all-gather, copy-out, and next all-gather), and other @@ -255,7 +255,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool): # Backward: loss.sum().backward() mem_mb = self._get_peak_active_memory_mb() - if enable_fsdp_fp8_all_gather: + if enable_fsdp_float8_all_gather: # Non-block parameters (fp32), 2x block non-linear weight # parameters (fp32) and block linear-weight parameters (fp8) # (current copy-out and next all-gather), 1x block gradients (fp32) @@ -294,7 +294,7 @@ def test_weight_subclass_dynamic(self): # Check for a single FSDP paramter group module_fp32 = self.init_single_module() float8_linear_config = Float8LinearConfig( - enable_fsdp_fp8_all_gather=True, + enable_fsdp_float8_all_gather=True, emulate=True, ) module = convert_to_float8_training( @@ -360,7 +360,7 @@ def get_expected_all_gather_size(module: nn.Module): module_fp32 = self.init_single_module() ref_module = copy.deepcopy(module_fp32) float8_linear_config = Float8LinearConfig( - enable_fsdp_fp8_all_gather=True, + enable_fsdp_float8_all_gather=True, ) module_fp32 = convert_to_float8_training( module_fp32, config=float8_linear_config @@ -418,15 +418,15 @@ def test_fp32_fp8_single_module_parity(self): [False, True], [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], ) - for 
enable_fsdp_fp8_all_gather, scaling_type_weight in choices: + for enable_fsdp_float8_all_gather, scaling_type_weight in choices: float8_linear_config1 = Float8LinearConfig( - enable_fsdp_fp8_all_gather=False, + enable_fsdp_float8_all_gather=False, cast_config_weight=Float8TensorCastConfig( scaling_type=scaling_type_weight ), ) float8_linear_config2 = Float8LinearConfig( - enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather, + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, cast_config_weight=Float8TensorCastConfig( scaling_type=scaling_type_weight ), @@ -466,15 +466,15 @@ def test_fp32_fp8_multi_module_parity(self): [False, True], [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], ) - for enable_fsdp_fp8_all_gather, scaling_type_weight in choices: + for enable_fsdp_float8_all_gather, scaling_type_weight in choices: float8_linear_config1 = Float8LinearConfig( - enable_fsdp_fp8_all_gather=False, + enable_fsdp_float8_all_gather=False, cast_config_weight=Float8TensorCastConfig( scaling_type=scaling_type_weight ), ) float8_linear_config2 = Float8LinearConfig( - enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather, + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, cast_config_weight=Float8TensorCastConfig( scaling_type=scaling_type_weight ), @@ -545,7 +545,7 @@ def test_delayed_scaling_inplace_update(self): """ module = self.init_single_module() float8_linear_config = Float8LinearConfig( - enable_fsdp_fp8_all_gather=True, + enable_fsdp_float8_all_gather=True, cast_config_weight=Float8TensorCastConfig( scaling_type=TensorScalingType.DELAYED ), From ed1693ec3b3deaab8f3591960d7b4b18b1dcb46c Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 25 Jul 2024 14:48:12 -0700 Subject: [PATCH 3/6] rename `DelayedScalingRecipe` to `DelayedScalingConfig` (#333) Summary: Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/333 1. rename `DelayedScalingRecipe` to `DelayedScalingConfig` 2. move this to `config.py` and make user facing Reviewed By: weifengpy Differential Revision: D60252067 fbshipit-source-id: ec233df1e0d03fdc649a19de1722ee45d5029aa6 --- float8_experimental/__init__.py | 2 ++ float8_experimental/config.py | 31 ++++++++++++++++++++ float8_experimental/float8_linear.py | 34 +++------------------- float8_experimental/float8_linear_utils.py | 2 +- 4 files changed, 38 insertions(+), 31 deletions(-) diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 95491f3..4c1f255 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # Lets define a few top level things here from float8_experimental.config import ( + DelayedScalingConfig, Float8LinearConfig, Float8TensorCastConfig, TensorScalingType, @@ -30,6 +31,7 @@ __all__ = [ # configuration + "DelayedScalingConfig", "TensorScalingType", "Float8LinearConfig", "Float8TensorCastConfig", diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 2e9eacf..ea088e3 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -29,6 +29,30 @@ class Float8TensorCastConfig: scaling_type: TensorScalingType = TensorScalingType.DYNAMIC +@dataclass(frozen=True) +class DelayedScalingConfig: + """ + Configuration for delayed scaling. + + Note: for now, `history_len` values must be the same for all layers in the + model using delayed scaling. 
+ + TODO(future): serialization for recipes + """ + + # Controls the history length of amax buffers + history_len: int = 16 + + # Controls the way to calculate current scale from amax history + # TODO(future): add other functions as needed, hardcoded or user defined + scale_fn_name: str = "max" + + def __post_init__(self): + assert ( + self.scale_fn_name == "max" + ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." + + @dataclass(frozen=True) class Float8LinearConfig: """ @@ -71,6 +95,13 @@ class Float8LinearConfig: # If True, emulation is used instead of hardware accelerated gemm emulate: bool = False + # Configuration for delayed scaling + # Note: this is actually applied per-tensor, but only using the same + # configuration for all tensors and layers in the model is currently + # supported. If in the future we add support for a more fine grained + # configuration, this field may move to per-tensor configs. + delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() + # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 1d8519e..581f9f3 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -131,23 +131,6 @@ def backward(ctx, go): return res, *empty_grads -@dataclasses.dataclass -class DelayedScalingRecipe: - # Controls the history length of amax buffers - history_len: int - - # Controls the way to calculate current scale from amax history - # TODO(future): add other functions as needed, hardcoded or user defined - scale_fn_name: str - - def __init__(self, history_len: int = 16, scale_fn_name: str = "max"): - self.history_len = history_len - self.scale_fn_name = scale_fn_name - assert ( - self.scale_fn_name == "max" - ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." - - class Float8Linear(torch.nn.Linear): """ Note: this is **not** a public API and is only intended to be used @@ -161,13 +144,9 @@ class Float8Linear(torch.nn.Linear): def __init__(self, *args, **kwargs): """ Additional arguments on top of `torch.nn.Linear`'s arguments: - * `delayed_scaling_recipe`: configuration for delayed scaling * `config`: Float8LinearConfig """ - delayed_scaling_recipe = kwargs.pop( - "delayed_scaling_recipe", DelayedScalingRecipe() - ) # Amax scales should always be kept as float32. self.always_float32_buffers = set() config = kwargs.pop("config") @@ -187,11 +166,6 @@ def __init__(self, *args, **kwargs): self.config = config - # TODO(future): have a unique recipe per buffer instead of one per - # module, saving implementing that until we need it. 
- # TODO(future): serialization for recipes - self.recipe = delayed_scaling_recipe - self.create_buffers() # TODO(future): user level configuration of gemms @@ -237,7 +211,7 @@ def __init__(self, *args, **kwargs): def create_buffers(self): # Default values for history buffers, see above TODO - history_len = self.recipe.history_len + history_len = self.config.delayed_scaling_config.history_len device = self.weight.device # TODO(future PR): dtype values below don't have the other float8 # flavors, fix it @@ -307,7 +281,7 @@ def cast_x_to_float8( x = x.to(autocast_dtype) if self.scaling_type_input is TensorScalingType.DELAYED: - scale_fn_name = self.recipe.scale_fn_name + scale_fn_name = self.config.delayed_scaling_config.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( x, self.fp8_amax_input, @@ -338,7 +312,7 @@ def cast_w_to_float8( if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: - scale_fn_name = self.recipe.scale_fn_name + scale_fn_name = self.config.delayed_scaling_config.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( w, self.fp8_amax_weight, @@ -370,7 +344,7 @@ def cast_w_to_float8( def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: if self.scaling_type_grad_output is TensorScalingType.DELAYED: - scale_fn_name = self.recipe.scale_fn_name + scale_fn_name = self.config.delayed_scaling_config.scale_fn_name y = NoopFwToFloat8E5M2Bw.apply( y, self.fp8_amax_grad_output, diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index e3a758a..c72b620 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -237,7 +237,7 @@ def inner_func(): fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output x_dtypes.add(child.last_seen_input_dtype) - scale_fn_recipes.add(child.recipe.scale_fn_name) + scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) # TODO This way to get the activation dtype is not ideal if len(x_dtypes) != 1: From b9b606e69a344494c1aa43ac5b917cc71825c9b1 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 25 Jul 2024 14:48:12 -0700 Subject: [PATCH 4/6] add per-gemm config to `Float8LinearConfig` (#334) Summary: Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/334 Previously the per-gemm configuration had to be hardcoded in library code. This PR exposes it to the top-level UX by adding a `Float8GemmConfig` field to `Float8LinearConfig`. Note that today the only supported configuration option is `use_fast_accum`. In the future, configuring output_dtype and whether to keep a gemm in higher precision would go here. 
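For illustration, a minimal usage sketch of the new per-gemm knobs (not part of this PR's test plan; it assumes the top-level `Float8LinearConfig`, `Float8GemmConfig`, and `convert_to_float8_training` exports from this stack, and `model` is just a placeholder module):

```python
# Hedged sketch: configure fast accumulation per gemm via the new
# Float8GemmConfig fields on Float8LinearConfig.
import torch.nn as nn
from float8_experimental import (
    Float8GemmConfig,
    Float8LinearConfig,
    convert_to_float8_training,
)

config = Float8LinearConfig(
    # the `output` gemm defaults to use_fast_accum=True
    gemm_config_output=Float8GemmConfig(use_fast_accum=True),
    # the backward gemms default to use_fast_accum=False
    gemm_config_grad_input=Float8GemmConfig(use_fast_accum=False),
    gemm_config_grad_weight=Float8GemmConfig(use_fast_accum=False),
)

# hypothetical toy model, only to show where the config is consumed
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
model = convert_to_float8_training(model, config=config)
```
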
Reviewed By: weifengpy Differential Revision: D60252069 fbshipit-source-id: bca34eb49e1bf046f937e32b11b2369b535d56e6 --- float8_experimental/__init__.py | 2 ++ float8_experimental/config.py | 19 +++++++++++++++++++ float8_experimental/float8_linear.py | 18 +++++++++++------- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 4c1f255..08c0ac4 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -6,6 +6,7 @@ # Lets define a few top level things here from float8_experimental.config import ( DelayedScalingConfig, + Float8GemmConfig, Float8LinearConfig, Float8TensorCastConfig, TensorScalingType, @@ -33,6 +34,7 @@ # configuration "DelayedScalingConfig", "TensorScalingType", + "Float8GemmConfig", "Float8LinearConfig", "Float8TensorCastConfig", # top level UX diff --git a/float8_experimental/config.py b/float8_experimental/config.py index ea088e3..6408ac7 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -53,6 +53,17 @@ def __post_init__(self): ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." +@dataclass(frozen=True) +class Float8GemmConfig: + """ + Configuration for a float8 gemm. + """ + + # If True, fast accumulation in lower precision is used. + # Note: this flag is currently a no-op if emulation is turned on. + use_fast_accum: bool = False + + @dataclass(frozen=True) class Float8LinearConfig: """ @@ -67,6 +78,14 @@ class Float8LinearConfig: cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig() cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig() + # + # Per-gemm configuration for gemms calculating `output`, `grad_input` and + # `grad_weight` + # + gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) + gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() + gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig() + # # Per-linear configuration # diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 581f9f3..c598a93 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -168,24 +168,28 @@ def __init__(self, *args, **kwargs): self.create_buffers() - # TODO(future): user level configuration of gemms self.linear_mm_config = LinearMMConfig( - # input + # output ScaledMMConfig( emulate, - True if not emulate else False, + self.config.gemm_config_output.use_fast_accum, False, self.config.pad_inner_dim, ), - # weight + # grad_input ScaledMMConfig( emulate, - True if not emulate else False, + self.config.gemm_config_grad_input.use_fast_accum, + False, + self.config.pad_inner_dim, + ), + # grad_weight + ScaledMMConfig( + emulate, + self.config.gemm_config_grad_weight.use_fast_accum, False, self.config.pad_inner_dim, ), - # grad_output - ScaledMMConfig(emulate, False, False, self.config.pad_inner_dim), ) # Note: is_amax_initialized is not a buffer to avoid data dependent From 8352894b7ebffdcf3e96943165165b25d51a0d3e Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 25 Jul 2024 14:48:12 -0700 Subject: [PATCH 5/6] rename all variables to use input/weight/grad_output notation (#335) Summary: Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/335 In https://github.com/pytorch-labs/float8_experimental/pull/323 we changed the user facing variable notation from `x/w/dL_dY` to `input/weight/grad_output`. 
This PR follows up by changing most of the internal variables to also match the new notation, to reduce confusion. Reviewed By: weifengpy Differential Revision: D60252071 fbshipit-source-id: b91ec5b975df550962418eafc93f1904d64a3dd8 --- float8_experimental/float8_dynamic_utils.py | 4 +- float8_experimental/float8_linear.py | 76 ++++++++++--------- float8_experimental/float8_tensor.py | 55 +++++++------- float8_experimental/float8_tensor_parallel.py | 6 +- float8_experimental/fsdp_utils.py | 10 +-- float8_experimental/inference.py | 4 +- test/test_base.py | 24 +++--- test/test_compile.py | 4 +- test/test_dtensor.py | 11 ++- 9 files changed, 100 insertions(+), 94 deletions(-) diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index 9fe9d17..bfacd65 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -42,7 +42,7 @@ def backward(ctx, gradY): gradY_scale, e5m2_dtype, linear_mm_config=ctx.linear_mm_config, - gemm_input_role=GemmInputRole.DL_DY, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, ) return fp8_tensor, None @@ -51,7 +51,7 @@ def cast_to_float8_e4m3_dynamic( inpt_tensor: torch.Tensor, linear_mm_config: LinearMMConfig, reduce_amax: bool = False, - gemm_input_role: GemmInputRole = GemmInputRole.X, + gemm_input_role: GemmInputRole = GemmInputRole.INPUT, ) -> Float8Tensor: if tensor_already_casted_to_fp8(inpt_tensor): return inpt_tensor diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index c598a93..42eeb86 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -125,7 +125,7 @@ def backward(ctx, go): fp8_scale_grad_output, e5m2_dtype, linear_mm_config=ctx.linear_mm_config, - gemm_input_role=GemmInputRole.DL_DY, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, ) empty_grads = None, None, None, None, None, None return res, *empty_grads @@ -273,8 +273,8 @@ def convert_amax_buffer_to_float32(self): if self._buffers[key] is not None: self._buffers[key] = self._buffers[key].to(torch.float32) - def cast_x_to_float8( - self, x: torch.Tensor, is_amax_initialized: bool + def cast_input_to_float8( + self, input: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: # Duplicate the autocast logic for F.linear, so that the output # of our module has the right original precision @@ -282,12 +282,12 @@ def cast_x_to_float8( # For now, hardcode to GPU's autocast dtype # if we need CPU support in the future, we can add it autocast_dtype = torch.get_autocast_gpu_dtype() - x = x.to(autocast_dtype) + input = input.to(autocast_dtype) if self.scaling_type_input is TensorScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( - x, + input, self.fp8_amax_input, self.fp8_amax_history_input, self.fp8_scale_input, @@ -296,29 +296,29 @@ def cast_x_to_float8( is_amax_initialized, reduce_amax=True, ) - x_fp8 = Float8Tensor.to_float8( - x, + input_fp8 = Float8Tensor.to_float8( + input, self.fp8_scale_input, e4m3_dtype, self.fp8_amax_input, linear_mm_config=self.linear_mm_config, - gemm_input_role=GemmInputRole.X, + gemm_input_role=GemmInputRole.INPUT, ) else: assert self.scaling_type_input is TensorScalingType.DYNAMIC - x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config) - return x_fp8 + input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config) + return input_fp8 - def cast_w_to_float8( - self, w: torch.Tensor, is_amax_initialized: bool + 
def cast_weight_to_float8( + self, weight: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: if self.scaling_type_weight is TensorScalingType.DELAYED: if isinstance(self.weight, Float8Tensor): # cast by FSDP - w_fp8 = self.weight + weight_fp8 = self.weight else: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( - w, + weight, self.fp8_amax_weight, self.fp8_amax_history_weight, self.fp8_scale_weight, @@ -328,29 +328,31 @@ def cast_w_to_float8( reduce_amax=False, ) - w_fp8 = Float8Tensor.to_float8( - w, + weight_fp8 = Float8Tensor.to_float8( + weight, self.fp8_scale_weight, e4m3_dtype, self.fp8_amax_weight, linear_mm_config=self.linear_mm_config, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ) else: assert self.scaling_type_weight is TensorScalingType.DYNAMIC if isinstance(self.weight, Float8Tensor): # cast by FSDP - w_fp8 = self.weight + weight_fp8 = self.weight else: - w_fp8 = cast_to_float8_e4m3_dynamic( - self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W + weight_fp8 = cast_to_float8_e4m3_dynamic( + self.weight, + self.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, ) - return w_fp8 + return weight_fp8 - def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: + def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: if self.scaling_type_grad_output is TensorScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - y = NoopFwToFloat8E5M2Bw.apply( - y, + output = NoopFwToFloat8E5M2Bw.apply( + output, self.fp8_amax_grad_output, self.fp8_amax_history_grad_output, self.fp8_scale_grad_output, @@ -360,10 +362,10 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: ) else: assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC - y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config) - return y + output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config) + return output - def float8_pre_forward(self, x): + def float8_pre_forward(self, input): if not self.enable_pre_and_post_forward: return if ( @@ -374,7 +376,7 @@ def float8_pre_forward(self, x): raise AssertionError( "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward" ) - self.last_seen_input_dtype = x.dtype + self.last_seen_input_dtype = input.dtype def float8_post_forward(self): if not self.enable_pre_and_post_forward: @@ -388,25 +390,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.has_any_delayed_scaling: self.float8_pre_forward(input) - x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized) - w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized) + input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) + weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) - y = torch.matmul(x_fp8, w_fp8.t()) + output = torch.matmul(input_fp8, weight_fp8.t()) - # Cast gradY to float8_e5m2 during backward - y = self.cast_y_to_float8_in_bw(y) + # Cast grad_output to float8_e5m2 during backward + output = self.cast_output_to_float8_in_bw(output) if self.bias is not None: - y = y + self.bias.to(y.dtype) + output = output + self.bias.to(output.dtype) if self.has_any_delayed_scaling: self.float8_post_forward() - return y + return output def scaling_repr(self): # add scaling settings without using too many characters - # example: "x:del,w:del,dldy:dyn" - return 
f"x:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},dldy:{self.scaling_type_grad_output.short_str()}" + # example: "i:del,w:del,go:dyn" + return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}" def extra_repr(self): s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"' diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 3b108be..a46e7ce 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -27,21 +27,21 @@ # # There are three gemms in a forward + backward of a Linear layer: # -# 1. x @ w_t = y (forward pass) -# 2. dL_dY @ w = dL_dX (backward pass) -# 3. x_t @ dL_dY = dL_dW (backward pass) +# 1. input @ weight_t = output (forward pass) +# 2. grad_output @ weight = grad_input (backward pass) +# 3. input_t @ grad_output = grad_weight (backward pass) # # In the formulas above, there are: -# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t). -# - Note that dL_dY_t is implied because of memory format requirements +# A. six input tensors (input, input_t, weight, weight_t, grad_output, grad_output_t). +# - Note that grad_output_t is implied because of memory format requirements # of float8 gemms -# B. three output tensors (y, dL_dX, dL_dW) +# B. three output tensors (output, grad_input, grad_weight) # # We want each input tensor, gemm, and output tensor to be configurable. # The state of this configuration today is: # # i. pairs of input tensors (non-t and t variants) have their scaling -# configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear +# configurable via the scaling_type_* arguments to Float8Linear # ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing # iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed # to configure all three gemms, also not user facing @@ -60,11 +60,12 @@ # The object below is not user facing and exists for convenience, # to allow Float8Tensor to use -# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is +# the right config based on which gemm from gemms with outputs +# `output`, `grad_input`, `grad_weight` is # being called. LinearMMConfig = namedtuple( "LinearMMConfig", - ["y", "dL_dX", "dL_dW"], + ["output", "grad_input", "grad_weight"], defaults=[ ScaledMMConfig(False, True, False, False), ScaledMMConfig(False, False, False, False), @@ -81,9 +82,9 @@ class GemmInputRole(enum.Enum): gemm is performed. 
""" - X = "x" - W = "w" - DL_DY = "dL_dY" + INPUT = "input" + WEIGHT = "weight" + GRAD_OUTPUT = "grad_output" # choose which scaled_mm_config to use based on gemm inputs @@ -93,21 +94,21 @@ def choose_scaled_mm_config( b_role: GemmInputRole, b_linear_mm_config: LinearMMConfig, ): - if a_role is GemmInputRole.X and b_role is GemmInputRole.W: + if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT: assert ( - a_linear_mm_config.y == b_linear_mm_config.y - ), f"linear_mm_config.y mismatch: {a_linear_mm_config.y} vs {b_linear_mm_config.y}" - return a_linear_mm_config.y - elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W: + a_linear_mm_config.output == b_linear_mm_config.output + ), f"linear_mm_config.output mismatch: {a_linear_mm_config.output} vs {b_linear_mm_config.output}" + return a_linear_mm_config.output + elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT: assert ( - a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX - ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}" - return a_linear_mm_config.dL_dX - elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X: + a_linear_mm_config.grad_input == b_linear_mm_config.grad_input + ), f"linear_mm_config.grad_input mismatch: {a_linear_mm_config.grad_input} vs {b_linear_mm_config.grad_input}" + return a_linear_mm_config.grad_input + elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT: assert ( - a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW - ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}" - return a_linear_mm_config.dL_dW + a_linear_mm_config.grad_weight == b_linear_mm_config.grad_weight + ), f"linear_mm_config.grad_weight mismatch: {a_linear_mm_config.grad_weight} vs {b_linear_mm_config.grad_weight}" + return a_linear_mm_config.grad_weight else: raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}") @@ -207,7 +208,7 @@ def forward( float8_dtype=e4m3_dtype, amax_buffer: Optional[torch.Tensor] = None, linear_mm_config: Optional[LinearMMConfig] = None, - gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer. Args @@ -287,7 +288,7 @@ def __new__( scale: torch.Tensor, orig_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig], - gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): assert ( scale.numel() == 1 @@ -348,7 +349,7 @@ def to_float8( float8_dtype: torch.dtype, amax_buffer: Optional[torch.Tensor] = None, linear_mm_config: Optional[LinearMMConfig] = None, - gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): """Converts a higher precision tensor to float8 in a differentiable way. 
diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 1841553..99850ad 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -48,7 +48,7 @@ def _prepare_input_fn( input_tensor = cast_to_float8_e4m3_dynamic( input_tensor, mod.linear_mm_config, - gemm_input_role=GemmInputRole.X, + gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) # transform the input layouts to the desired layouts of ColwiseParallel @@ -101,7 +101,7 @@ def _prepare_input_fn( input_tensor = cast_to_float8_e4m3_dynamic( input_tensor, mod.linear_mm_config, - gemm_input_role=GemmInputRole.X, + gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) if input_layouts != desired_input_layouts: @@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): dt_inp = cast_to_float8_e4m3_dynamic( dt_inp, self.linear_mm_config, - gemm_input_role=GemmInputRole.X, + gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index eef9ec1..c124ee4 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -169,14 +169,14 @@ def fsdp_pre_all_gather(self, mesh): self._precomputed_scale, torch.float8_e4m3fn, linear_mm_config=self._linear_mm_config, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ) else: float8_tensor = cast_to_float8_e4m3_dynamic( self._tensor, self._linear_mm_config, reduce_amax=True, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ) return (float8_tensor._data,), (float8_tensor._scale,) @@ -199,7 +199,7 @@ def fsdp_post_all_gather( scale, param_dtype, self._linear_mm_config, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ), (data,) @@ -362,7 +362,7 @@ def fsdp_pre_all_gather(self, mesh): e4m3_dtype, self._amax_buffer, self._linear_mm_config, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ) return (float8_tensor._data,), (float8_tensor._scale,) @@ -385,5 +385,5 @@ def fsdp_post_all_gather( scale, param_dtype, self._linear_mm_config, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ), (data,) diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py index d24fedb..0c10589 100644 --- a/float8_experimental/inference.py +++ b/float8_experimental/inference.py @@ -132,7 +132,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: scale, dtype, self.linear_mm_config, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ) self.weight = nn.Parameter(quantized_weight) self.weight.requires_grad = False @@ -205,7 +205,7 @@ def cast_to_float8_e4m3_inference( scale, e4m3_dtype, linear_mm_config=linear_mm_config, - gemm_input_role=GemmInputRole.X, + gemm_input_role=GemmInputRole.INPUT, ) diff --git a/test/test_base.py b/test/test_base.py index 4d36ad1..ffc8d0c 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -395,7 +395,7 @@ def test_repr(self): config=config, ) s = m.__repr__() - assert "x:dyn,w:del,dldy:dyn" in s + assert "i:dyn,w:del,go:dyn" in s class TestScaledMM: @@ -464,18 +464,18 @@ def test_different_configs_error(self): x_scale, fp8_dtype, linear_mm_config=linear_config_a, - gemm_input_role=GemmInputRole.X, + 
gemm_input_role=GemmInputRole.INPUT, ) b = Float8Tensor.to_float8( x_fp32, x_scale, fp8_dtype, linear_mm_config=linear_config_b, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ) with pytest.raises( AssertionError, - match="linear_mm_config.y mismatch", + match="linear_mm_config.output mismatch", ): a @ b @@ -499,10 +499,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): b_scale = tensor_to_scale(b, input_dtype).float() a_fp8 = Float8Tensor.to_float8( - a, a_scale, input_dtype, gemm_input_role=GemmInputRole.X + a, a_scale, input_dtype, gemm_input_role=GemmInputRole.INPUT ) b_fp8 = Float8Tensor.to_float8( - b, b_scale, input_dtype, gemm_input_role=GemmInputRole.W + b, b_scale, input_dtype, gemm_input_role=GemmInputRole.WEIGHT ) with pytest.raises( @@ -523,14 +523,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a_scale, input_dtype, linear_mm_config=pad_config, - gemm_input_role=GemmInputRole.X, + gemm_input_role=GemmInputRole.INPUT, ) b_fp8 = Float8Tensor.to_float8( b, b_scale, input_dtype, linear_mm_config=pad_config, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ) out_padded = a_fp8 @ b_fp8 out_padded.to(compare_type) @@ -546,14 +546,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a_scale, input_dtype, linear_mm_config=emulated_config, - gemm_input_role=GemmInputRole.X, + gemm_input_role=GemmInputRole.INPUT, ) b_fp8 = Float8Tensor.to_float8( b, b_scale, input_dtype, linear_mm_config=emulated_config, - gemm_input_role=GemmInputRole.W, + gemm_input_role=GemmInputRole.WEIGHT, ) out_emualted = a_fp8 @ b_fp8 out_emualted.to(compare_type) @@ -606,8 +606,8 @@ def test_swap_root_linear(self): config = Float8LinearConfig(emulate=emulate) module = convert_to_float8_training(module, config=config) self.assertIsInstance(module, Float8Linear) - self.assertEqual(module.linear_mm_config.y.emulate, emulate) - self.assertEqual(module.linear_mm_config.y.emulate, emulate) + self.assertEqual(module.linear_mm_config.output.emulate, emulate) + self.assertEqual(module.linear_mm_config.output.emulate, emulate) def test_swap_root_linear_with_children_raises(self): for emulate in [True, False]: diff --git a/test/test_compile.py b/test/test_compile.py index a457dd8..e7b5285 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -256,9 +256,9 @@ def test_float8_graph_output(self): type(y_compiled._orig_dtype) ) assert isinstance( - y_compiled._linear_mm_config.y.emulate, bool + y_compiled._linear_mm_config.output.emulate, bool ), "Float8Tensor._emulate should be a bool but got {}".format( - type(y_compiled._linear_mm_config.y.emulate) + type(y_compiled._linear_mm_config.output.emulate) ) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 0b294d8..eeca6df 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -87,10 +87,10 @@ def test_scaled_mm(mesh: DeviceMesh, size=16): y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() x_fp8 = Float8Tensor.to_float8( - x_fp32, x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X + x_fp32, x_scale, fp8_dtype, gemm_input_role=GemmInputRole.INPUT ) y_fp8 = Float8Tensor.to_float8( - y_fp32, y_scale, fp8_dtype, gemm_input_role=GemmInputRole.W + y_fp32, y_scale, fp8_dtype, gemm_input_role=GemmInputRole.WEIGHT ) dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) @@ -164,10 +164,13 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): dist_target = distribute_tensor(target, mesh, [Shard(0)]) dist_x_fp8 = 
Float8Tensor.to_float8( - dist_x_fp32, dist_x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X + dist_x_fp32, dist_x_scale, fp8_dtype, gemm_input_role=GemmInputRole.INPUT ) dist_weight_fp8 = Float8Tensor.to_float8( - dist_wight_fp32, dist_weight_scale, fp8_dtype, gemm_input_role=GemmInputRole.W + dist_wight_fp32, + dist_weight_scale, + fp8_dtype, + gemm_input_role=GemmInputRole.WEIGHT, ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) From 0aca10aced1c4b3abdf00960d83316732cb08ed1 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 25 Jul 2024 14:48:12 -0700 Subject: [PATCH 6/6] rename TensorScalingType->ScalingType, Float8TensorCastConfig->CastConfig (#337) Summary: Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/337 old: TensorScalingType, Float8TensorCastConfig new: ScalingType, CastConfig reason: make code more readable, previous names were too verbose, the `tensor` information is implied Reviewed By: weifengpy Differential Revision: D60252070 fbshipit-source-id: c3e5bb78e2f7c66fc73975a21e7fa1db210b2c18 --- README.md | 10 +-- benchmarks/bench_linear_float8.py | 20 ++--- benchmarks/bench_multi_gpu.py | 14 +-- benchmarks/profile_linear_float8.py | 20 ++--- float8_experimental/__init__.py | 8 +- float8_experimental/config.py | 17 ++-- float8_experimental/float8_linear.py | 26 +++--- float8_experimental/float8_linear_utils.py | 8 +- float8_experimental/float8_tensor_parallel.py | 6 +- float8_experimental/fsdp_utils.py | 5 +- test/test_base.py | 49 ++++------ test/test_compile.py | 90 +++++++------------ test/test_fsdp.py | 12 +-- test/test_fsdp2/test_fsdp2.py | 42 +++------ test/test_fsdp2/test_fsdp2_common.py | 4 +- test/test_fsdp_compile.py | 14 +-- test/test_inference_flows.py | 2 +- test/test_numerics_integration.py | 26 +++--- 18 files changed, 143 insertions(+), 230 deletions(-) diff --git a/README.md b/README.md index 6c14b8e..642529f 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ This is theoretically the most performant recipe as it minimizes memory reads. from float8_experimental import ( convert_to_float8_training, sync_float8_amax_and_scale_history, - TensorScalingType, + ScalingType, ) # create model @@ -95,13 +95,13 @@ m = Model(...) 
# gated with config.enable_amax_init and # config.enable_pre_and_post_forward are needed for # autocast + compile + FSDP + float8 to work -from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig +from float8_experimental import Float8LinearConfig, ScalingType, CastConfig config = Float8LinearConfig( enable_amax_init = False, # only needed for autocast + compile + FSDP + float8 delayed enable_pre_and_post_forward, False # only needed for autocast + compile + FSDP + float8 delayed - cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED), - cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED), - cast_config_grad_output=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED), + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), ) # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index 2780600..0bbf116 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -14,11 +14,7 @@ import torch import torch.utils.benchmark as benchmark -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( linear_requires_sync, @@ -107,15 +103,13 @@ def main( device = "cuda" print(f"Compile is set to | {compile}") - scaling_type_input = TensorScalingType(scaling_type_input) - scaling_type_weight = TensorScalingType(scaling_type_weight) - scaling_type_grad_output = TensorScalingType(scaling_type_grad_output) + scaling_type_input = ScalingType(scaling_type_input) + scaling_type_weight = ScalingType(scaling_type_weight) + scaling_type_grad_output = ScalingType(scaling_type_grad_output) config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), ) # LLaMa 2 70B single-node weight shapes diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py index 5cb5223..a741dec 100644 --- a/benchmarks/bench_multi_gpu.py +++ b/benchmarks/bench_multi_gpu.py @@ -14,11 +14,7 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.utils.benchmark as benchmark -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear_utils import ( convert_to_float8_training, sync_float8_amax_and_scale_history, @@ -33,11 +29,9 @@ lr = 0.01 config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED), - cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED), - cast_config_grad_output=Float8TensorCastConfig( - 
scaling_type=TensorScalingType.DELAYED - ), + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), ) diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index e90c5b3..716ceed 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -18,11 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -217,15 +213,13 @@ def main( assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported" assert dtype_filter in ("both", "float8", "bfloat16") - scaling_type_input = TensorScalingType(scaling_type_input) - scaling_type_weight = TensorScalingType(scaling_type_weight) - scaling_type_grad_output = TensorScalingType(scaling_type_grad_output) + scaling_type_input = ScalingType(scaling_type_input) + scaling_type_weight = ScalingType(scaling_type_weight) + scaling_type_grad_output = ScalingType(scaling_type_grad_output) config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), ) scaling_repr = "_".join( [ diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 08c0ac4..8fd8476 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -5,11 +5,11 @@ # LICENSE file in the root directory of this source tree. 
 # Lets define a few top level things here
 from float8_experimental.config import (
+    CastConfig,
     DelayedScalingConfig,
     Float8GemmConfig,
     Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
+    ScalingType,
 )
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
@@ -33,10 +33,10 @@
 __all__ = [
     # configuration
     "DelayedScalingConfig",
-    "TensorScalingType",
+    "ScalingType",
     "Float8GemmConfig",
     "Float8LinearConfig",
-    "Float8TensorCastConfig",
+    "CastConfig",
     # top level UX
     "convert_to_float8_training",
     "linear_requires_sync",
diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 6408ac7..5d1bf9f 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -8,25 +8,26 @@
 from dataclasses import dataclass
 
 
-class TensorScalingType(enum.Enum):
+# TODO(future): consider renaming to ScalingType
+class ScalingType(enum.Enum):
     DELAYED = "delayed"
     DYNAMIC = "dynamic"
 
     def short_str(self):
-        if self is TensorScalingType.DELAYED:
+        if self is ScalingType.DELAYED:
             return "del"
         else:
-            assert self is TensorScalingType.DYNAMIC
+            assert self is ScalingType.DYNAMIC
             return "dyn"
 
 
 @dataclass(frozen=True)
-class Float8TensorCastConfig:
+class CastConfig:
     """
     Configuration for casting a single tensor to float8
     """
 
-    scaling_type: TensorScalingType = TensorScalingType.DYNAMIC
+    scaling_type: ScalingType = ScalingType.DYNAMIC
 
 
 @dataclass(frozen=True)
@@ -74,9 +75,9 @@ class Float8LinearConfig:
     #
     # Per-tensor configuration for `input`, `weight`, `grad_output`
     #
-    cast_config_input: Float8TensorCastConfig = Float8TensorCastConfig()
-    cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
-    cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()
+    cast_config_input: CastConfig = CastConfig()
+    cast_config_weight: CastConfig = CastConfig()
+    cast_config_grad_output: CastConfig = CastConfig()
 
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 42eeb86..fd76a8e 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -14,7 +14,7 @@
 
 import torch
 
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
@@ -159,9 +159,9 @@ def __init__(self, *args, **kwargs):
         self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
         # Convenience flag to skip code related to delayed scaling
         self.has_any_delayed_scaling = (
-            self.scaling_type_input is TensorScalingType.DELAYED
-            or self.scaling_type_weight is TensorScalingType.DELAYED
-            or self.scaling_type_grad_output is TensorScalingType.DELAYED
+            self.scaling_type_input is ScalingType.DELAYED
+            or self.scaling_type_weight is ScalingType.DELAYED
+            or self.scaling_type_grad_output is ScalingType.DELAYED
         )
         self.config = config
 
@@ -284,7 +284,7 @@ def cast_input_to_float8(
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        if self.scaling_type_input is TensorScalingType.DELAYED:
+        if self.scaling_type_input is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 input,
@@ -305,14 +305,14 @@
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
-            assert self.scaling_type_input is TensorScalingType.DYNAMIC
+            assert self.scaling_type_input is ScalingType.DYNAMIC
             input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
         return input_fp8
 
     def cast_weight_to_float8(
         self, weight: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
-        if self.scaling_type_weight is TensorScalingType.DELAYED:
+        if self.scaling_type_weight is ScalingType.DELAYED:
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
@@ -337,7 +337,7 @@ def cast_weight_to_float8(
                     gemm_input_role=GemmInputRole.WEIGHT,
                 )
         else:
-            assert self.scaling_type_weight is TensorScalingType.DYNAMIC
+            assert self.scaling_type_weight is ScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
@@ -349,7 +349,7 @@ def cast_weight_to_float8(
         return weight_fp8
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
-        if self.scaling_type_grad_output is TensorScalingType.DELAYED:
+        if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             output = NoopFwToFloat8E5M2Bw.apply(
                 output,
@@ -361,7 +361,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
                 self.linear_mm_config,
             )
         else:
-            assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
+            assert self.scaling_type_grad_output is ScalingType.DYNAMIC
             output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
         return output
 
@@ -448,7 +448,7 @@ def from_float(
         # 2. buffers need to be already created for the delayed scaling version
         # of the weight wrapper to be initialized
         if config.enable_fsdp_float8_all_gather:
-            if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
+            if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
@@ -456,9 +456,7 @@ def from_float(
                     )
                 )
             else:
-                assert (
-                    config.cast_config_weight.scaling_type is TensorScalingType.DELAYED
-                )
+                assert config.cast_config_weight.scaling_type is ScalingType.DELAYED
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDelayedFloat8CastTensor(
                         new_mod.weight,
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index c72b620..7fffcde 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -9,7 +9,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear import Float8Linear
 
 from float8_experimental.float8_utils import (
@@ -27,9 +27,9 @@ def linear_requires_sync(config: Float8LinearConfig):
     """Returns whether the given linear_type requires sync before forward."""
     return any(
         [
-            config.cast_config_input.scaling_type is TensorScalingType.DELAYED,
-            config.cast_config_weight.scaling_type is TensorScalingType.DELAYED,
-            config.cast_config_grad_output.scaling_type is TensorScalingType.DELAYED,
+            config.cast_config_input.scaling_type is ScalingType.DELAYED,
+            config.cast_config_weight.scaling_type is ScalingType.DELAYED,
+            config.cast_config_grad_output.scaling_type is ScalingType.DELAYED,
         ]
     )
 
diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py
index 99850ad..eea7376 100644
--- a/float8_experimental/float8_tensor_parallel.py
+++ b/float8_experimental/float8_tensor_parallel.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from float8_experimental.config import TensorScalingType
+from float8_experimental.config import ScalingType
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
@@ -28,8 +28,8 @@ def _float8_linear_supports_float8_allgather(m):
     # TODO(future): add support for delayed scaling for activations
     # and gradients
     return (
-        m.scaling_type_input == TensorScalingType.DYNAMIC
-        and m.scaling_type_grad_output == TensorScalingType.DYNAMIC
+        m.scaling_type_input == ScalingType.DYNAMIC
+        and m.scaling_type_grad_output == ScalingType.DYNAMIC
     )
 
 
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index c124ee4..5fbefc9 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -33,13 +33,12 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         optim.step()
         precompute_float8_dynamic_scale_for_fsdp(model)
     """
-    from float8_experimental.config import TensorScalingType
+    from float8_experimental.config import ScalingType
     from float8_experimental.float8_linear import Float8Linear
     from torch.distributed._tensor import DTensor
 
     if any(
-        isinstance(m, Float8Linear)
-        and m.scaling_type_weight is TensorScalingType.DELAYED
+        isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED
         for m in module.modules()
     ):
         raise NotImplementedError("Only supports delayed scaling")
diff --git a/test/test_base.py b/test/test_base.py
index ffc8d0c..2f7c717 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -16,11 +16,8 @@
 import torch
 import torch.nn as nn
 
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
+from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
@@ -183,15 +180,15 @@ def _test_linear_impl(
         amax_buffer_names = []
         amax_history_buffer_names = []
         scale_buffer_names = []
-        if config.cast_config_input.scaling_type is TensorScalingType.DELAYED:
+        if config.cast_config_input.scaling_type is ScalingType.DELAYED:
             amax_buffer_names.append("fp8_amax_input")
             amax_history_buffer_names.append("fp8_amax_history_input")
             scale_buffer_names.append("fp8_scale_input")
-        if config.cast_config_weight.scaling_type is TensorScalingType.DELAYED:
+        if config.cast_config_weight.scaling_type is ScalingType.DELAYED:
             amax_buffer_names.append("fp8_amax_weight")
             amax_history_buffer_names.append("fp8_amax_history_weight")
             scale_buffer_names.append("fp8_scale_weight")
-        if config.cast_config_grad_output.scaling_type is TensorScalingType.DELAYED:
+        if config.cast_config_grad_output.scaling_type is ScalingType.DELAYED:
             amax_buffer_names.append("fp8_amax_grad_output")
             amax_history_buffer_names.append("fp8_amax_history_grad_output")
             scale_buffer_names.append("fp8_scale_grad_output")
@@ -223,14 +220,14 @@ def _test_linear_impl(
     @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
-        "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+        "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
     )
     @pytest.mark.parametrize(
-        "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+        "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
     )
     @pytest.mark.parametrize(
         "scaling_type_grad_output",
-        [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+        [ScalingType.DELAYED, ScalingType.DYNAMIC],
     )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
@@ -239,9 +236,9 @@ def test_linear(
         self,
         x_shape,
         emulate: bool,
-        scaling_type_input: TensorScalingType,
-        scaling_type_weight: TensorScalingType,
-        scaling_type_grad_output: TensorScalingType,
+        scaling_type_input: ScalingType,
+        scaling_type_weight: ScalingType,
+        scaling_type_grad_output: ScalingType,
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
@@ -257,11 +254,9 @@ def test_linear(
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(
-            cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-            cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-            cast_config_grad_output=Float8TensorCastConfig(
-                scaling_type=scaling_type_grad_output
-            ),
+            cast_config_input=CastConfig(scaling_type=scaling_type_input),
+            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
             emulate=emulate,
         )
         self._test_linear_impl(
@@ -292,15 +287,9 @@ def test_autocast_outputs(
 
         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(
-            cast_config_input=Float8TensorCastConfig(
-                scaling_type=TensorScalingType.DELAYED
-            ),
-            cast_config_weight=Float8TensorCastConfig(
-                scaling_type=TensorScalingType.DELAYED
-            ),
-            cast_config_grad_output=Float8TensorCastConfig(
-                scaling_type=TensorScalingType.DELAYED
-            ),
+            cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
             emulate=emulate,
         )
         m = Float8Linear.from_float(copy.deepcopy(m_ref), config)
@@ -385,9 +374,7 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
     def test_repr(self):
         m = nn.Linear(32, 16)
         config = Float8LinearConfig(
-            cast_config_weight=Float8TensorCastConfig(
-                scaling_type=TensorScalingType.DELAYED
-            ),
+            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
             emulate=True,
         )
         m = Float8Linear.from_float(
diff --git a/test/test_compile.py b/test/test_compile.py
index e7b5285..a71b879 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -13,11 +13,7 @@
 import torch
 import torch.nn as nn
 
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
@@ -67,13 +63,13 @@ def _test_compile_base(
 
 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize(
-    "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@@ -81,18 +77,16 @@ def _test_compile_base(
 def test_eager_only(
     fullgraph,
     emulate: bool,
-    scaling_type_input: TensorScalingType,
-    scaling_type_weight: TensorScalingType,
-    scaling_type_grad_output: TensorScalingType,
+    scaling_type_input: ScalingType,
+    scaling_type_weight: ScalingType,
+    scaling_type_grad_output: ScalingType,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
         emulate=emulate,
     )
     _test_compile_base(
@@ -106,31 +100,29 @@ def test_eager_only(
 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize(
-    "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
     fullgraph,
     emulate: bool,
-    scaling_type_input: TensorScalingType,
-    scaling_type_weight: TensorScalingType,
-    scaling_type_grad_output: TensorScalingType,
+    scaling_type_input: ScalingType,
+    scaling_type_weight: ScalingType,
+    scaling_type_grad_output: ScalingType,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
         emulate=emulate,
     )
     _test_compile_base(
@@ -144,31 +136,29 @@ def test_aot_eager(
 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize("emulate", [False])
 @pytest.mark.parametrize(
-    "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
 )
 @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
     fullgraph,
     emulate: bool,
-    scaling_type_input: TensorScalingType,
-    scaling_type_weight: TensorScalingType,
-    scaling_type_grad_output: TensorScalingType,
+    scaling_type_input: ScalingType,
+    scaling_type_weight: ScalingType,
+    scaling_type_grad_output: ScalingType,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
         emulate=emulate,
     )
     _test_compile_base(
@@ -270,15 +260,9 @@ def test_sync_amax_func():
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     )
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
-        cast_config_weight=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
+        cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+        cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+        cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
     )
     float8_mod = convert_to_float8_training(
         module,
@@ -314,15 +298,9 @@ def test_sync_amax_func_cuda_graph_success():
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     ).to("cuda")
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
-        cast_config_weight=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
+        cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+        cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+        cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
     )
     convert_to_float8_training(
         my_module,
diff --git a/test/test_fsdp.py b/test/test_fsdp.py
index c7f86cc..f5be23b 100644
--- a/test/test_fsdp.py
+++ b/test/test_fsdp.py
@@ -21,11 +21,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -78,12 +74,10 @@ def fsdp_main(rank, world_size, args):
     model_fp8 = copy.deepcopy(model)
 
     scaling_type_weight = (
-        TensorScalingType.DYNAMIC
-        if use_weight_dynamic_scaling
-        else TensorScalingType.DELAYED
+        ScalingType.DYNAMIC if use_weight_dynamic_scaling else ScalingType.DELAYED
     )
     config = Float8LinearConfig(
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
         # TODO(future): delete this arg as it's always False
         emulate=False,
     )
diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py
index 6d5719a..266bd6d 100644
--- a/test/test_fsdp2/test_fsdp2.py
+++ b/test/test_fsdp2/test_fsdp2.py
@@ -8,11 +8,7 @@
 import torch._dynamo.testing
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import convert_to_float8_training
 from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from test_fsdp2_common import check_parity_bf16_mp, check_parity_no_mp
@@ -86,8 +82,8 @@ def test_transformer_parity(self):
                 "enable_fsdp_float8_all_gather": [False, True],
                 "precompute": [False, True],
                 "scaling_type_weight": [
-                    TensorScalingType.DYNAMIC,
-                    TensorScalingType.DELAYED,
+                    ScalingType.DYNAMIC,
+                    ScalingType.DELAYED,
                 ],
                 "compile_transformer_block": [False, True],
             },
@@ -98,12 +94,12 @@ def _test_transformer_parity(
         self,
         enable_fsdp_float8_all_gather: bool,
         precompute: bool,
-        scaling_type_weight: TensorScalingType,
+        scaling_type_weight: ScalingType,
         compile_transformer_block: bool,
     ):
         if not enable_fsdp_float8_all_gather and precompute:
             return
-        elif scaling_type_weight is TensorScalingType.DELAYED and precompute:
+        elif scaling_type_weight is ScalingType.DELAYED and precompute:
             return
 
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
@@ -114,7 +110,7 @@ def _test_transformer_parity(
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         float8_linear_config1 = Float8LinearConfig(
-            cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
+            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
         )
         convert_to_float8_training(
             ref_module,
@@ -126,7 +122,7 @@ def _test_transformer_parity(
             ref_module.layers.register_module(layer_id, transformer_block)
         float8_linear_config2 = Float8LinearConfig(
             enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-            cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
+            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
         )
         convert_to_float8_training(
             module,
@@ -416,20 +412,16 @@ def test_fp32_fp8_single_module_parity(self):
         """
         choices = itertools.product(
             [False, True],
-            [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
+            [ScalingType.DYNAMIC, ScalingType.DELAYED],
         )
         for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
             float8_linear_config1 = Float8LinearConfig(
                 enable_fsdp_float8_all_gather=False,
-                cast_config_weight=Float8TensorCastConfig(
-                    scaling_type=scaling_type_weight
-                ),
+                cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
             )
             float8_linear_config2 = Float8LinearConfig(
                 enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-                cast_config_weight=Float8TensorCastConfig(
-                    scaling_type=scaling_type_weight
-                ),
+                cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
             )
             module_fp32 = self.init_single_module()
             ref_module = copy.deepcopy(module_fp32)
@@ -464,20 +456,16 @@ def test_fp32_fp8_multi_module_parity(self):
         """
         choices = itertools.product(
             [False, True],
-            [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
+            [ScalingType.DYNAMIC, ScalingType.DELAYED],
         )
         for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
             float8_linear_config1 = Float8LinearConfig(
                 enable_fsdp_float8_all_gather=False,
-                cast_config_weight=Float8TensorCastConfig(
-                    scaling_type=scaling_type_weight
-                ),
+                cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
             )
             float8_linear_config2 = Float8LinearConfig(
                 enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-                cast_config_weight=Float8TensorCastConfig(
-                    scaling_type=scaling_type_weight
-                ),
+                cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
             )
             module = self.init_multi_module().cuda()
             ref_module = copy.deepcopy(module)
@@ -546,9 +534,7 @@ def test_delayed_scaling_inplace_update(self):
         module = self.init_single_module()
         float8_linear_config = Float8LinearConfig(
             enable_fsdp_float8_all_gather=True,
-            cast_config_weight=Float8TensorCastConfig(
-                scaling_type=TensorScalingType.DELAYED
-            ),
+            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
         )
         m_fp8 = convert_to_float8_training(
             module,
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py
index 8f1fa80..f26278c 100644
--- a/test/test_fsdp2/test_fsdp2_common.py
+++ b/test/test_fsdp2/test_fsdp2_common.py
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
@@ -44,7 +44,7 @@ def check_parity_no_mp(
             if (
                 model is fsdp_model
                 and precompute
-                and config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC
+                and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC
             ):
                 precompute_float8_dynamic_scale_for_fsdp(model)
 
diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py
index bc80023..e20ab15 100644
--- a/test/test_fsdp_compile.py
+++ b/test/test_fsdp_compile.py
@@ -18,7 +18,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import Float8LinearConfig
-from float8_experimental.config import Float8TensorCastConfig, TensorScalingType
+from float8_experimental.config import CastConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
@@ -57,15 +57,9 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
     # to get around this, we can disable amax init
     config = Float8LinearConfig(
         enable_amax_init=False,
-        cast_config_input=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
-        cast_config_weight=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=TensorScalingType.DELAYED
-        ),
+        cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+        cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+        cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
         emulate=emulate,
     )
 
diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py
index 35b640a..421b7a9 100644
--- a/test/test_inference_flows.py
+++ b/test/test_inference_flows.py
@@ -13,7 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.config import TensorScalingType
+from float8_experimental.config import ScalingType
 from float8_experimental.float8_linear_utils import convert_to_float8_training
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error
diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py
index 4d4446c..73e3211 100644
--- a/test/test_numerics_integration.py
+++ b/test/test_numerics_integration.py
@@ -14,11 +14,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -79,22 +75,22 @@ def init_weights(self, init_std: float):
 
 class TestFloat8NumericsIntegrationTest:
     @pytest.mark.parametrize(
-        "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+        "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
     )
     @pytest.mark.parametrize(
-        "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+        "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC]
     )
     @pytest.mark.parametrize(
         "scaling_type_grad_output",
-        [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+        [ScalingType.DELAYED, ScalingType.DYNAMIC],
     )
     @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
     @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw(
         self,
-        scaling_type_input: TensorScalingType,
-        scaling_type_weight: TensorScalingType,
-        scaling_type_grad_output: TensorScalingType,
+        scaling_type_input: ScalingType,
+        scaling_type_weight: ScalingType,
+        scaling_type_grad_output: ScalingType,
     ):
         # TODO(later): maybe add float16 back if it becomes important
         data_dtype = torch.bfloat16
@@ -114,11 +110,9 @@ def test_encoder_fw_bw(
         # for now just test the encoder to simplify things
         model_fp8 = copy.deepcopy(model_ref)
         config = Float8LinearConfig(
-            cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-            cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-            cast_config_grad_output=Float8TensorCastConfig(
-                scaling_type=scaling_type_grad_output
-            ),
+            cast_config_input=CastConfig(scaling_type=scaling_type_input),
+            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
         )
         convert_to_float8_training(
             model_fp8,