This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add option to use tensor hooks for Dynamic Linear #198

Closed · wants to merge 3 commits
5 changes: 5 additions & 0 deletions float8_experimental/config.py
@@ -14,3 +14,8 @@
# this doesn't work with autocast + torch.compile + FSDP. Enabling this
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True

# If True, dynamic linear uses hooks for activation casting
Contributor Author:

need to figure out if we want this

# TODO(before land): add test coverage for both cases
# dynamic_use_activation_hooks = True
# dynamic_use_activation_hooks = False
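
If this flag survives, a hedged sketch of how it could be consumed is below; this is purely illustrative, since the diff currently threads use_activation_hooks through from_float explicitly and leaves the config entries commented out.

```python
import torch.nn as nn

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# Hypothetical wiring: fall back to the (currently commented-out) config flag when
# the caller does not pass the kwarg; device="cuda" mirrors the repo's tests.
use_hooks = getattr(config, "dynamic_use_activation_hooks", False)

mod = nn.Linear(16, 32, bias=False, device="cuda")
m_fp8 = Float8DynamicLinear.from_float(mod, emulate=True, use_activation_hooks=use_hooks)
```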
74 changes: 58 additions & 16 deletions float8_experimental/float8_dynamic_linear.py
@@ -7,9 +7,10 @@
A wrapper around a `torch.nn.Linear` module which does fp8 compute.
"""

import float8_experimental.config as config
import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated


@@ -31,13 +32,27 @@ def forward(

@staticmethod
def backward(ctx, gradY):
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
gradY_scaled = gradY * gradY_scale
bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
return (
Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=ctx.emulate),
None,
)
fp8_tensor = to_fp8_no_autograd(gradY, torch.float8_e5m2, ctx.emulate)
return fp8_tensor, None


def cast_x_to_float8_e4m3fn_pre_hook(module, args):
"""
Hook to cast the incoming activation to `torch.float8_e4m3fn`
"""
return module.cast_to_float8_e4m3fn(args[0])


def cast_grad_to_float8_e5m2_backward_forward_hook(module, input, output):
"""This is a forward hook that sends the output of the model through
a no-op in the forward but a cast to float8_e5m2 in the backward.

Args:
module (nn.Module): the module to cast the output of
input (Tensor): the input to the module forward call
output (Tensor): the output of the module forward
"""
return module.cast_to_float8_e5m2_bw(output)


class Float8DynamicLinear(torch.nn.Linear):
@@ -46,38 +61,65 @@ class Float8DynamicLinear:
conversion to fp8 of the input and weight tensors.
"""

def __init__(self, use_activation_hooks: bool, **super_kwargs):
"""
Args:
use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
"""
super().__init__(**super_kwargs)

self.use_activation_hooks = use_activation_hooks

def forward(self, x):
x_fp8 = self.cast_to_float8(x)
w_fp8 = self.cast_to_float8(self.weight)
# cast x to float8_e4m3fn if not using activation hooks
x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)

# cast w to float8_e4m3fn
w_fp8 = self.cast_to_float8_e4m3fn(self.weight)

y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

# Cast gradY to float8_e5m2 during backward
y = self.cast_to_float8e5m2_bw(y)
# Cast gradY to float8_e5m2 during backward if not using activation hooks
if not self.use_activation_hooks:
y = self.cast_to_float8_e5m2_bw(y)

return y

def cast_to_float8(self, inpt_tensor):
def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
)

def cast_to_float8e5m2_bw(self, gradY):
def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)

@classmethod
def from_float(cls, mod, emulate: bool = False):
def from_float(
cls, mod, emulate: bool = False, use_activation_hooks: bool = False
) -> "Float8DynamicLinear":
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
"""
with torch.device("meta"):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
super_kwargs = {
"in_features": mod.in_features,
"out_features": mod.out_features,
"bias": False,
}
new_mod = cls(use_activation_hooks, **super_kwargs)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
if new_mod.use_activation_hooks:
# install the hooks
new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
new_mod.register_forward_hook(
cast_grad_to_float8_e5m2_backward_forward_hook
)
return new_mod
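
For readers unfamiliar with the hook mechanics this path relies on: a forward pre-hook that returns a value replaces the module's positional inputs, and a forward hook that returns a value replaces the module's output, which is how cast_grad_to_float8_e5m2_backward_forward_hook can route the output through the no-op-forward / cast-gradY-backward autograd function. A minimal standalone sketch of just that replace-input/replace-output mechanism, using plain dtype casts instead of Float8Tensor:

```python
import torch
import torch.nn as nn


def cast_input_pre_hook(module, args):
    # Returning a value from a forward pre-hook replaces the positional inputs.
    return args[0].to(torch.bfloat16)


def cast_output_hook(module, args, output):
    # Returning a value from a forward hook replaces the module's output.
    # (In this PR, this is the spot where the no-op-forward / cast-backward op runs.)
    return output.to(torch.float32)


lin = nn.Linear(8, 8).to(torch.bfloat16)
lin.register_forward_pre_hook(cast_input_pre_hook)
lin.register_forward_hook(cast_output_hook)

y = lin(torch.randn(4, 8))  # fp32 in -> bf16 inside the linear -> fp32 out
print(y.dtype)              # torch.float32
```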
4 changes: 3 additions & 1 deletion float8_experimental/float8_linear.py
@@ -298,14 +298,16 @@ def forward(self, x):
return y

@classmethod
def from_float(cls, mod, emulate: bool = False):
def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic
"""
assert not use_activation_hooks, "use_activation_hooks is not supported yet!"
# TODO Follow up! This is a great idea but we need the mixin base to create real
# Tensors and the Linear base to create empty params
# with torch.device("meta"):
13 changes: 10 additions & 3 deletions float8_experimental/float8_linear_utils.py
@@ -23,23 +23,30 @@ class LinearType(Enum):


def get_float8_linear(
linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False
linear_type: LinearType,
linear_ref: torch.nn.Linear,
emulate: bool = False,
use_activation_hooks: bool = False,
):
"""Returns a Float8Linear module of the given type, initialized from linear_ref.
Args:
linear_type: The type of Float8Linear to return.
linear_ref: The linear module to initialize from.
emulate: Whether to emulate the fp8 matmul logic in float32.
use_activation_hooks: Whether to use activation hooks for dynamic linear.
"""
LINEAR_TYPE_MAP = {
LinearType.DELAYED: Float8Linear,
LinearType.DYNAMIC: Float8DynamicLinear,
}
if linear_type not in LINEAR_TYPE_MAP:
raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}")

if use_activation_hooks and linear_type != LinearType.DYNAMIC:
raise ValueError("use_activation_hooks is only supported for dynamic linear")
return LINEAR_TYPE_MAP[linear_type].from_float(
copy.deepcopy(linear_ref), emulate=emulate
copy.deepcopy(linear_ref),
emulate=emulate,
use_activation_hooks=use_activation_hooks,
)


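
A short usage sketch of the new keyword, based on the signature above; emulate=True and the CUDA device are assumptions mirroring the repo's tests, not requirements stated in this diff.

```python
import torch
import torch.nn as nn

from float8_experimental.float8_linear_utils import LinearType, get_float8_linear

ref = nn.Linear(16, 32, bias=False, device="cuda")

# Default: dynamic linear with the casts inlined in forward().
m_inline = get_float8_linear(LinearType.DYNAMIC, ref, emulate=True)

# New in this PR: same module, but the casts are installed as module hooks.
m_hooks = get_float8_linear(
    LinearType.DYNAMIC, ref, emulate=True, use_activation_hooks=True
)

# Hooks with delayed scaling is rejected with a ValueError by the check above.
# get_float8_linear(LinearType.DELAYED, ref, use_activation_hooks=True)

y = m_hooks(torch.randn(4, 16, device="cuda"))
```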
25 changes: 24 additions & 1 deletion float8_experimental/float8_tensor.py
@@ -7,7 +7,11 @@

import torch

from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
from float8_experimental.float8_utils import (
tensor_to_amax,
tensor_to_scale,
to_fp8_saturated,
)

aten = torch.ops.aten

@@ -170,3 +174,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):

# Do not force the Float8Tensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl


def to_fp8_no_autograd(
x: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
) -> Float8Tensor:
"""Convert a tensor to float8 without autograd
This is used in multiple places in the codebase to convert a tensor to float8

This function will calculate the scale, do the scaling, and then convert to a Float8Tensor
Args:
x: the tensor to convert
float8_dtype: the float8 dtype to use
emulate: whether to emulate the matmuls in fp32
"""
x_scale = tensor_to_scale(x, float8_dtype)
x_scaled = x * x_scale
bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
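
For intuition, the helper is a scale-then-saturated-cast. Below is a rough standalone sketch of the same arithmetic, assuming (as elsewhere in this codebase) that the scale is the fp8 dtype's max divided by the tensor's amax and that the saturated cast clamps to the representable range before converting; it needs a PyTorch build that ships the float8 dtypes.

```python
import torch


def to_fp8_saturated_sketch(x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Clamp to the fp8 dtype's representable range, then cast down.
    fp8_max = torch.finfo(float8_dtype).max
    return x.clamp(min=-fp8_max, max=fp8_max).to(float8_dtype)


x = torch.randn(8, 8)
scale = torch.finfo(torch.float8_e5m2).max / x.abs().max()  # assumed form of tensor_to_scale
bits_fp8 = to_fp8_saturated_sketch(x * scale, torch.float8_e5m2)
# Float8Tensor then wraps (bits_fp8, scale, x.dtype) so downstream ops can rescale.
```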
20 changes: 20 additions & 0 deletions test/conftest.py
@@ -0,0 +1,20 @@
import pytest


@pytest.fixture
def x_fail_activation_hooks(request):
use_activation_hooks = request.getfixturevalue("use_activation_hooks")
if use_activation_hooks:
request.node.add_marker(
pytest.mark.xfail(reason="use_activation_hooks is not supported for AOT")
)


@pytest.fixture
def x_fail_activation_hooks_with_delayed(request):
linear_type = request.getfixturevalue("linear_type")
use_activation_hooks = request.getfixturevalue("use_activation_hooks")
if use_activation_hooks and linear_type == linear_type.DELAYED:
request.node.add_marker(
pytest.mark.xfail(reason="use_activation_hooks is not supported for AOT")
)
76 changes: 60 additions & 16 deletions test/test_base.py
@@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None:


class TestFloat8Linear:
def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
def _test_linear_impl(
self,
x,
m_ref,
linear_type: LinearType,
emulate: bool,
use_activation_hooks: bool = False,
):
m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
for _ in range(2):
if linear_requires_sync(linear_type):
sync_float8_amax_and_scale_history(m_fp8)
@@ -112,7 +119,15 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
@pytest.mark.parametrize("emulate", [True, False])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
@pytest.mark.parametrize("use_activation_hooks", [True, False])
@pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
def test_linear_nobias(
self,
x_shape,
linear_type: LinearType,
emulate: bool,
use_activation_hooks: bool,
):
if not emulate:
if not torch.cuda.is_available():
warnings.warn("CUDA not available")
@@ -125,16 +140,23 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):

x = torch.randn(*x_shape, device="cuda")
m_ref = nn.Linear(16, 32, bias=False, device="cuda")
self._test_linear_impl(x, m_ref, linear_type, emulate)
self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)

@pytest.mark.parametrize("emulate", [True, False])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
@pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
def test_linear_bias(
self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype
self,
x_shape,
linear_type: LinearType,
emulate: bool,
linear_dtype: torch.dtype,
use_activation_hooks: bool,
):
if not emulate:
if not torch.cuda.is_available():
@@ -148,25 +170,52 @@ def test_linear_bias(

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
self._test_linear_impl(x, m_ref, linear_type, emulate)
self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)

m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = Float8Linear.from_float(m, emulate)
@pytest.mark.parametrize("emulate", [True, False])
@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
Contributor Author:

There was some testing that was globbed together before; this splits the test into two.

@pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
def test_autocast_outputs(
self,
linear_type: LinearType,
emulate: bool,
linear_dtype: torch.dtype,
use_activation_hooks: bool,
):
if not emulate:
if not torch.cuda.is_available():
warnings.warn("CUDA not available")
pytest.skip()
elif torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()

m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
sync_float8_amax_and_scale_history(m)
if linear_requires_sync(linear_type):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

# autocast on
with torch.autocast("cuda"):
sync_float8_amax_and_scale_history(m)
if linear_requires_sync(linear_type):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

with torch.autocast("cuda", dtype=torch.bfloat16):
sync_float8_amax_and_scale_history(m)
if linear_requires_sync(linear_type):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert (
y.dtype == torch.bfloat16
@@ -180,11 +229,6 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
emulate = (
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
)
x_shape = (16, 16)

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
self._test_linear_impl(x, m_ref, linear_type, emulate)

m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = Float8Linear.from_float(m, emulate)