Add option to use tensor hooks for Dynamic Linear #198
Changes from 1 commit
@@ -7,9 +7,10 @@
A wrapper around a `torch.nn.Linear` module which does fp8 compute.
"""

import float8_experimental.config as config
import torch

-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
@@ -31,13 +32,23 @@ def forward(

    @staticmethod
    def backward(ctx, gradY):
-        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
-        gradY_scaled = gradY * gradY_scale
-        bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
-        return (
-            Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=ctx.emulate),
-            None,
-        )
+        fp8_tensor = to_fp8_no_autograd(gradY, torch.float8_e5m2, ctx.emulate)
+        return fp8_tensor, None
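For context on this change: the inlined code removed above spells out what `to_fp8_no_autograd` evidently does here, namely compute a scale for `gradY`, apply it, saturate-cast to `float8_e5m2`, and wrap the bits in a `Float8Tensor`. A rough sketch reconstructed from those removed lines (not the library's actual implementation, which may differ in details):

```python
# Hypothetical reconstruction of to_fp8_no_autograd, based only on the inlined
# code it replaces in this diff; the real helper lives in
# float8_experimental.float8_tensor and may differ.
import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated


def to_fp8_no_autograd_sketch(
    tensor: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
) -> Float8Tensor:
    # Per-tensor scale for the target fp8 dtype, then scale, saturate-cast,
    # and wrap the raw fp8 bits together with the scale and original dtype.
    scale = tensor_to_scale(tensor, float8_dtype)
    tensor_scaled = tensor * scale
    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
    return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
```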


+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8_e4m3fn(args[0])


+def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
Testing my understanding: it looks to me like the torch.compile issue is related to the backward pre-hook and tensor subclass interactions. But could this be solved by using [a forward hook]? I believe this would solve the backward pre-hook issue, but I'm not sure whether it would hit subclass issues (hopefully not).

Ahhh, great idea; unfortunately it is still erroring for compile. I think that is a little cleaner, but is it enough for a good interaction with DTensor? I have this locally and can push it up, but I don't know which one ultimately gets us closer.

Yeah, I think a forward_hook might be a little cleaner and easier for DTensor to interact with. We would need to actually try composing those hooks to see if there's any gap. I feel we should try to land the forward_hook approach in this PR, and if we find we need [...]. (A rough sketch of a forward-hook variant of this cast appears after the hook definitions below.)

Sounds good!
""" | ||
Hook to cast the incoming gradient to `torch.float8_e5m2` | ||
""" | ||
gradY = grad_output[0] | ||
return (to_fp8_no_autograd(gradY, torch.float8_e5m2, module.emulate),) | ||
|
||
|
||
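To make the alternative floated in the review thread above concrete, here is one possible reading of it: instead of a full backward pre-hook, register a regular forward hook that routes the module output through `cast_to_float8_e5m2_bw` (i.e. the `NoopFwToFloat8E5M2Bw` autograd function), so the cast of dL/dY to `float8_e5m2` still happens in backward but no backward hook is involved. This is an illustrative sketch of the idea under discussion, not code from this PR, and the hook name below is made up:

```python
# Hypothetical forward-hook variant (not part of this diff). A forward hook that
# returns a non-None value replaces the module's output; the wrapped autograd
# function is a no-op in forward and casts the incoming gradient to
# float8_e5m2 in backward.
def cast_dldy_to_float8_e5m2_forward_hook(module, args, output):
    return module.cast_to_float8_e5m2_bw(output)


# Registration would then look like:
#     new_mod.register_forward_hook(cast_dldy_to_float8_e5m2_forward_hook)
# instead of register_full_backward_pre_hook(...).
```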
class Float8DynamicLinear(torch.nn.Linear):

@@ -46,38 +57,65 @@ class Float8DynamicLinear(torch.nn.Linear):
    conversion to fp8 of the input and weight tensors.
    """

+    def __init__(self, use_activation_hooks: bool, **super_kwargs):
+        """
+        Args:
+            use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
+        """
+        super().__init__(**super_kwargs)
+
+        self.use_activation_hooks = use_activation_hooks
+
    def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
+        # cast x to float8_e4m3fn if not using activation hooks
+        x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)
+
+        # cast w to float8_e4m3fn
+        w_fp8 = self.cast_to_float8_e4m3fn(self.weight)

        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

-        # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        # Cast gradY to float8_e5m2 during backward if not using activation hooks
+        if not self.use_activation_hooks:
+            y = self.cast_to_float8_e5m2_bw(y)

        return y

-    def cast_to_float8(self, inpt_tensor):
+    def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
        return Float8Tensor.to_float8(
            inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
        )

-    def cast_to_float8e5m2_bw(self, gradY):
+    def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
        return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)

    @classmethod
-    def from_float(cls, mod, emulate: bool = False):
+    def from_float(
+        cls, mod, emulate: bool = False, use_activation_hooks: bool = False
+    ) -> "Float8DynamicLinear":
        """
        Create an nn.Linear with fp8 compute from a regular nn.Linear

        Args:
            mod (torch.nn.Linear): nn.Linear to convert
            emulate (bool): whether to emulate fp8 matmul logic in float32
+            use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
        """
        with torch.device("meta"):
-            new_mod = cls(mod.in_features, mod.out_features, bias=False)
+            super_kwargs = {
+                "in_features": mod.in_features,
+                "out_features": mod.out_features,
+                "bias": False,
+            }
+            new_mod = cls(use_activation_hooks, **super_kwargs)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.emulate = emulate
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(
+                cast_dldy_to_float8_e5m2_backward_pre_hook
+            )
        return new_mod
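Stepping back, the user-facing change in this file is small: `from_float` gains a `use_activation_hooks` flag, and when it is set the casts move out of `forward` and into module hooks. A minimal usage sketch, assuming the import path below matches the repo layout, with `emulate=True` so it can run on GPUs without native fp8 support (as the tests allow):

```python
# Minimal usage sketch for the new flag. The import path and the CUDA device
# choice mirror what this PR's tests appear to use; both are assumptions here.
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

m = nn.Linear(16, 32, bias=False, device="cuda")

# With use_activation_hooks=True, a forward pre-hook casts x to float8_e4m3fn
# and a full backward pre-hook casts dL/dY to float8_e5m2; with False, the
# same casts happen inside forward().
m_fp8 = Float8DynamicLinear.from_float(m, emulate=True, use_activation_hooks=True)

x = torch.randn(4, 16, device="cuda")
y = m_fp8(x)
y.sum().backward()
```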
@@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None:


class TestFloat8Linear:
-    def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
+    def _test_linear_impl(
+        self,
+        x,
+        m_ref,
+        linear_type: LinearType,
+        emulate: bool,
+        use_activation_hooks: bool = False,
+    ):
+        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
        for _ in range(2):
            if linear_requires_sync(linear_type):
                sync_float8_amax_and_scale_history(m_fp8)
@@ -112,7 +119,14 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
    @pytest.mark.parametrize("emulate", [True, False])
    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    def test_linear_nobias(
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        use_activation_hooks: bool,
+    ):
        if not emulate:
            if not torch.cuda.is_available():
                warnings.warn("CUDA not available")
@@ -122,19 +136,27 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
                )
                pytest.skip()
+        if use_activation_hooks and linear_type != LinearType.DYNAMIC:
+            pytest.skip("use_activation_hooks is only supported for dynamic linear")

        x = torch.randn(*x_shape, device="cuda")
        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)

    @pytest.mark.parametrize("emulate", [True, False])
    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
    @pytest.mark.parametrize(
        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
    )
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
    def test_linear_bias(
-        self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        use_activation_hooks: bool,
    ):
        if not emulate:
            if not torch.cuda.is_available():
@@ -146,27 +168,59 @@ def test_linear_bias(
                )
                pytest.skip()

+        if use_activation_hooks and linear_type != LinearType.DYNAMIC:
+            pytest.skip("use_activation_hooks is only supported for dynamic linear")

        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)

-        m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = Float8Linear.from_float(m, emulate)
+    @pytest.mark.parametrize("emulate", [True, False])
+    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
There was some testing that was globbed together before; this splits the test into two.
+    def test_autocast_outputs(
+        self,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        use_activation_hooks: bool,
+    ):
+        if not emulate:
+            if not torch.cuda.is_available():
+                warnings.warn("CUDA not available")
+                pytest.skip()
+            elif torch.cuda.get_device_capability() < (9, 0):
+                warnings.warn(
+                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
+                )
+                pytest.skip()

+        if use_activation_hooks and linear_type != LinearType.DYNAMIC:
+            pytest.skip("use_activation_hooks is only supported for dynamic linear")

+        m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
+        m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)

        # autocast off
        x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
-        sync_float8_amax_and_scale_history(m)
+        if linear_requires_sync(linear_type):
+            sync_float8_amax_and_scale_history(m)
        y = m(x)
        assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

        # autocast on
        with torch.autocast("cuda"):
-            sync_float8_amax_and_scale_history(m)
+            if linear_requires_sync(linear_type):
+                sync_float8_amax_and_scale_history(m)
            y = m(x)
        assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

        with torch.autocast("cuda", dtype=torch.bfloat16):
-            sync_float8_amax_and_scale_history(m)
+            if linear_requires_sync(linear_type):
+                sync_float8_amax_and_scale_history(m)
            y = m(x)
        assert (
            y.dtype == torch.bfloat16
@@ -180,11 +234,6 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
        emulate = (
            not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
        )
-        x_shape = (16, 16)
-
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)

        m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
        m = Float8Linear.from_float(m, emulate)
Need to figure out if we want this.