diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index f0ba914f..9df065bc 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -14,3 +14,8 @@
 # this doesn't work with autocast + torch.compile + FSDP. Enabling this
 # option is useful for safety, but not strictly necessary.
 enable_pre_and_post_forward = True
+
+# If True, dynamic linear uses hooks for activation casting
+# TODO(before land): add test coverage for both cases
+# dynamic_use_activation_hooks = True
+# dynamic_use_activation_hooks = False
diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 58e352da..2f4905fe 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -7,9 +7,10 @@
 A wrapper around a `torch.nn.Linear` module which does fp8 compute.
 """
+import float8_experimental.config as config

 import torch

-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
@@ -31,13 +32,27 @@ def forward(

     @staticmethod
     def backward(ctx, gradY):
-        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
-        gradY_scaled = gradY * gradY_scale
-        bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
-        return (
-            Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=ctx.emulate),
-            None,
-        )
+        fp8_tensor = to_fp8_no_autograd(gradY, torch.float8_e5m2, ctx.emulate)
+        return fp8_tensor, None
+
+
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8_e4m3fn(args[0])
+
+
+def cast_grad_to_float8_e5m2_backward_forward_hook(module, input, output):
+    """This is a forward hook that sends the output of the model through
+    a no-op in the forward but a cast to float8_e5m2 in the backward.
+
+    Args:
+        module (nn.Module): the module to cast the output of
+        input (Tensor): the input to the module forward call
+        output (Tensor): the output of the module forward
+    """
+    return module.cast_to_float8_e5m2_bw(output)


 class Float8DynamicLinear(torch.nn.Linear):
@@ -46,38 +61,65 @@ class Float8DynamicLinear(torch.nn.Linear):
     conversion to fp8 of the input and weight tensors.
""" + def __init__(self, use_activation_hooks: bool, **super_kwargs): + """ + Args: + use_activation_hooks (bool): whether to use activation hooks for casting to and from float8 + """ + super().__init__(**super_kwargs) + + self.use_activation_hooks = use_activation_hooks + def forward(self, x): - x_fp8 = self.cast_to_float8(x) - w_fp8 = self.cast_to_float8(self.weight) + # cast x to float8_e4m3fn if not using activation hooks + x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x) + + # cast w to float8_e4m3fn + w_fp8 = self.cast_to_float8_e4m3fn(self.weight) y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) - # Cast gradY to float8_e5m2 during backward - y = self.cast_to_float8e5m2_bw(y) + # Cast gradY to float8_e5m2 during backward if not using activation hooks + if not self.use_activation_hooks: + y = self.cast_to_float8_e5m2_bw(y) return y - def cast_to_float8(self, inpt_tensor): + def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor: scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn) return Float8Tensor.to_float8( inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate ) - def cast_to_float8e5m2_bw(self, gradY): + def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor: return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate) @classmethod - def from_float(cls, mod, emulate: bool = False): + def from_float( + cls, mod, emulate: bool = False, use_activation_hooks: bool = False + ) -> "Float8DynamicLinear": """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 + use_activation_hooks (bool): whether to use activation hooks for casting to and from float8 """ with torch.device("meta"): - new_mod = cls(mod.in_features, mod.out_features, bias=False) + super_kwargs = { + "in_features": mod.in_features, + "out_features": mod.out_features, + "bias": False, + } + new_mod = cls(use_activation_hooks, **super_kwargs) new_mod.weight = mod.weight new_mod.bias = mod.bias new_mod.emulate = emulate + if new_mod.use_activation_hooks: + # install the hooks + new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook) + new_mod.register_forward_hook( + cast_grad_to_float8_e5m2_backward_forward_hook + ) return new_mod diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 59285852..8c2719ce 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -298,14 +298,16 @@ def forward(self, x): return y @classmethod - def from_float(cls, mod, emulate: bool = False): + def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False): """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 + use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic """ + assert not use_activation_hooks, "use_activation_hooks is not supported yet!" # TODO Follow up! 
         # Tensors and the Linear base to create empty params
         # with torch.device("meta"):
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index c6152e8d..5d954bd1 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -23,13 +23,17 @@ class LinearType(Enum):


 def get_float8_linear(
-    linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False
+    linear_type: LinearType,
+    linear_ref: torch.nn.Linear,
+    emulate: bool = False,
+    use_activation_hooks: bool = False,
 ):
     """Returns a Float8Linear module of the given type, initialized from linear_ref.
     Args:
         linear_type: The type of Float8Linear to return.
         linear_ref: The linear module to initialize from.
         emulate: Whether to emulate the fp8 matmul logic in float32.
+        use_activation_hooks: Whether to use activation hooks for dynamic linear.
     """
     LINEAR_TYPE_MAP = {
         LinearType.DELAYED: Float8Linear,
@@ -37,9 +41,12 @@
     }
     if linear_type not in LINEAR_TYPE_MAP:
         raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}")
-
+    if use_activation_hooks and linear_type != LinearType.DYNAMIC:
+        raise ValueError("use_activation_hooks is only supported for dynamic linear")
     return LINEAR_TYPE_MAP[linear_type].from_float(
-        copy.deepcopy(linear_ref), emulate=emulate
+        copy.deepcopy(linear_ref),
+        emulate=emulate,
+        use_activation_hooks=use_activation_hooks,
     )
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 4450fce8..6063f7c1 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -7,7 +7,11 @@

 import torch

-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+from float8_experimental.float8_utils import (
+    tensor_to_amax,
+    tensor_to_scale,
+    to_fp8_saturated,
+)

 aten = torch.ops.aten
@@ -170,3 +174,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # Do not force the Float8Tensor type on the returned tensor

     __torch_function__ = torch._C._disabled_torch_function_impl
+
+
+def to_fp8_no_autograd(
+    x: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
+) -> Float8Tensor:
+    """Convert a tensor to float8 without autograd.
+
+    This is used in multiple places in the codebase to convert a tensor to float8.
+    This function calculates the scale, scales the tensor, and then converts the
+    result to a Float8Tensor.
+    Args:
+        x: the tensor to convert
+        float8_dtype: the float8 dtype to use
+        emulate: whether to emulate the matmuls in fp32
+    """
+    x_scale = tensor_to_scale(x, float8_dtype)
+    x_scaled = x * x_scale
+    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
+    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 00000000..289e1514
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,20 @@
+import pytest
+
+
+@pytest.fixture
+def x_fail_activation_hooks(request):
+    use_activation_hooks = request.getfixturevalue("use_activation_hooks")
+    if use_activation_hooks:
+        request.node.add_marker(
+            pytest.mark.xfail(reason="use_activation_hooks is not supported for AOT")
+        )
+
+
+@pytest.fixture
+def x_fail_activation_hooks_with_delayed(request):
+    linear_type = request.getfixturevalue("linear_type")
+    use_activation_hooks = request.getfixturevalue("use_activation_hooks")
+    if use_activation_hooks and linear_type == linear_type.DELAYED:
+        request.node.add_marker(
+            pytest.mark.xfail(reason="use_activation_hooks is not supported for delayed scaling")
+        )
diff --git a/test/test_base.py b/test/test_base.py
index ba1f6662..bf19253d 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None:


 class TestFloat8Linear:
-    def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
+    def _test_linear_impl(
+        self,
+        x,
+        m_ref,
+        linear_type: LinearType,
+        emulate: bool,
+        use_activation_hooks: bool = False,
+    ):
+        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
         for _ in range(2):
             if linear_requires_sync(linear_type):
                 sync_float8_amax_and_scale_history(m_fp8)
@@ -112,7 +119,15 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    def test_linear_nobias(
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        use_activation_hooks: bool,
+    ):
         if not emulate:
             if not torch.cuda.is_available():
                 warnings.warn("CUDA not available")
@@ -125,7 +140,7 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)

     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
     def test_linear_bias(
-        self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        use_activation_hooks: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -148,25 +170,52 @@ def test_linear_bias(
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)

-        m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = Float8Linear.from_float(m, emulate)
+    @pytest.mark.parametrize("emulate", [True, False])
+    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    def test_autocast_outputs(
+        self,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        use_activation_hooks: bool,
+    ):
+        if not emulate:
+            if not torch.cuda.is_available():
+                warnings.warn("CUDA not available")
+                pytest.skip()
+            elif torch.cuda.get_device_capability() < (9, 0):
+                warnings.warn(
+                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
+                )
+                pytest.skip()
+
+        m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
+        m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)

         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
-        sync_float8_amax_and_scale_history(m)
+        if linear_requires_sync(linear_type):
+            sync_float8_amax_and_scale_history(m)
         y = m(x)
         assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

         # autocast on
         with torch.autocast("cuda"):
-            sync_float8_amax_and_scale_history(m)
+            if linear_requires_sync(linear_type):
+                sync_float8_amax_and_scale_history(m)
             y = m(x)
         assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

         with torch.autocast("cuda", dtype=torch.bfloat16):
-            sync_float8_amax_and_scale_history(m)
+            if linear_requires_sync(linear_type):
+                sync_float8_amax_and_scale_history(m)
             y = m(x)
         assert (
             y.dtype == torch.bfloat16
@@ -180,11 +229,6 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
         emulate = (
             not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
         )
-        x_shape = (16, 16)
-
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)

         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         m = Float8Linear.from_float(m, emulate)
diff --git a/test/test_compile.py b/test/test_compile.py
index d39b7400..d0ebac61 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -27,6 +27,7 @@ def _test_compile_base(
     emulate: bool,
     linear_type: LinearType,
     dtype: torch.dtype,
+    use_activation_hooks: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
@@ -36,7 +37,7 @@ def _test_compile_base(

     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
     m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-    m_fp8 = get_float8_linear(linear_type, m_ref, emulate=emulate)
+    m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)

     m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph)
     m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph)
@@ -55,30 +56,67 @@ def _test_compile_base(
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
 @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+@pytest.mark.parametrize("use_activation_hooks", [False, True])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-def test_eager_only(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype):
+def test_eager_only(
+    fullgraph,
+    emulate: bool,
+    linear_type: LinearType,
+    dtype: torch.dtype,
+    use_activation_hooks: bool,
+):
+    if linear_type == LinearType.DELAYED and use_activation_hooks:
+        pytest.skip("use_activation_hooks is only supported for dynamic linear")
     torch._dynamo.reset()
-    _test_compile_base("eager", fullgraph, emulate, linear_type, dtype)
+    _test_compile_base(
+        "eager", fullgraph, emulate, linear_type, dtype, use_activation_hooks
+    )


 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+@pytest.mark.parametrize("use_activation_hooks", [False, True])
+# TODO: this shouldn't fail, but it currently does because of multiple fake modes
+@pytest.mark.usefixtures("x_fail_activation_hooks")
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-def test_aot_eager(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype):
+def test_aot_eager(
+    fullgraph,
+    emulate: bool,
+    linear_type: LinearType,
+    dtype: torch.dtype,
+    use_activation_hooks: bool,
+):
+    if linear_type == LinearType.DELAYED and use_activation_hooks:
+        pytest.skip("use_activation_hooks is only supported for dynamic linear")
     torch._dynamo.reset()
-    _test_compile_base("aot_eager", fullgraph, emulate, linear_type, dtype)
+    _test_compile_base(
+        "aot_eager", fullgraph, emulate, linear_type, dtype, use_activation_hooks
+    )


 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize("emulate", [False])
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
+@pytest.mark.parametrize("use_activation_hooks", [False, True])
+# TODO: this shouldn't fail, but it currently does because of multiple fake modes
+@pytest.mark.usefixtures("x_fail_activation_hooks")
 @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
-def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype):
+def test_inductor(
+    fullgraph,
+    emulate: bool,
+    linear_type: LinearType,
+    dtype: torch.dtype,
+    use_activation_hooks: bool,
+):
+    if linear_type == LinearType.DELAYED and use_activation_hooks:
+        pytest.skip("use_activation_hooks is only supported for dynamic linear")
     torch._dynamo.reset()
-    _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)
+    _test_compile_base(
+        "inductor", fullgraph, emulate, linear_type, dtype, use_activation_hooks
+    )


 class TestGraphBreaks(DynamoTestCase):