This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Update on "bring back torch.autograd.Function for float8 matmul"
Summary:

This is a redo of
#316

With upcoming support for scaling granularities other than tensorwise,
we need a good way to control which gemm kernel to call and how to scale
the input tensors in the forward and backward passes. A
`torch.autograd.Function` override is the cleanest way to do that, and as
of 2024 it works with `torch.compile`.
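
As a rough sketch of the pattern (not the code in this change; the class
name and casting details below are illustrative only), a custom
`torch.autograd.Function` provides separate hooks for the forward and
backward gemms, so each operand can be cast and scaled according to its
own configuration:

```python
import torch

class Float8MatmulSketch(torch.autograd.Function):
    # Minimal sketch: the float8 casting is elided into comments, since the
    # real implementation picks per-operand scaling based on the config.

    @staticmethod
    def forward(ctx, input, weight):
        # Real code would cast `input` and `weight` to float8 (e.g. e4m3,
        # with dynamic or delayed scaling) before calling the fwd gemm.
        ctx.save_for_backward(input, weight)
        return torch.mm(input, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # The two backward gemms can each use their own casting choice,
        # e.g. e5m2 for grad_output and e4m3 for input/weight.
        grad_input = torch.mm(grad_output, weight)
        grad_weight = torch.mm(grad_output.t(), input)
        return grad_input, grad_weight

# usage inside a Linear-like module: y = Float8MatmulSketch.apply(x, w)
```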

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 25, 2024
2 parents 1c3e320 + dbd5d02 commit 63efec6
Showing 17 changed files with 138 additions and 225 deletions.
20 changes: 7 additions & 13 deletions benchmarks/bench_linear_float8.py
```diff
@@ -14,11 +14,7 @@
 
 import torch
 import torch.utils.benchmark as benchmark
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
@@ -107,15 +103,13 @@ def main(
     device = "cuda"
     print(f"Compile is set to | {compile}")
 
-    scaling_type_input = TensorScalingType(scaling_type_input)
-    scaling_type_weight = TensorScalingType(scaling_type_weight)
-    scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
+    scaling_type_input = ScalingType(scaling_type_input)
+    scaling_type_weight = ScalingType(scaling_type_weight)
+    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
     )
 
     # LLaMa 2 70B single-node weight shapes
```
14 changes: 4 additions & 10 deletions benchmarks/bench_multi_gpu.py
```diff
@@ -14,11 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
@@ -33,11 +29,9 @@
 lr = 0.01
 
 config = Float8LinearConfig(
-    cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_grad_output=Float8TensorCastConfig(
-        scaling_type=TensorScalingType.DELAYED
-    ),
+    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
 )
 
 
```
20 changes: 7 additions & 13 deletions benchmarks/profile_linear_float8.py
```diff
@@ -18,11 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -217,15 +213,13 @@ def main(
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
 
-    scaling_type_input = TensorScalingType(scaling_type_input)
-    scaling_type_weight = TensorScalingType(scaling_type_weight)
-    scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
+    scaling_type_input = ScalingType(scaling_type_input)
+    scaling_type_weight = ScalingType(scaling_type_weight)
+    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
     )
     scaling_repr = "_".join(
         [
```
8 changes: 4 additions & 4 deletions float8_experimental/__init__.py
```diff
@@ -5,11 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.config import (
+    CastConfig,
     DelayedScalingConfig,
     Float8GemmConfig,
     Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
+    ScalingType,
 )
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
@@ -33,10 +33,10 @@
 __all__ = [
     # configuration
     "DelayedScalingConfig",
-    "TensorScalingType",
+    "ScalingType",
     "Float8GemmConfig",
     "Float8LinearConfig",
-    "Float8TensorCastConfig",
+    "CastConfig",
     # top level UX
     "convert_to_float8_training",
    "linear_requires_sync",
```
17 changes: 9 additions & 8 deletions float8_experimental/config.py
```diff
@@ -8,25 +8,26 @@
 from dataclasses import dataclass
 
 
-class TensorScalingType(enum.Enum):
+# TODO(future): consider renaming to ScalingType
+class ScalingType(enum.Enum):
     DELAYED = "delayed"
     DYNAMIC = "dynamic"
 
     def short_str(self):
-        if self is TensorScalingType.DELAYED:
+        if self is ScalingType.DELAYED:
             return "del"
         else:
-            assert self is TensorScalingType.DYNAMIC
+            assert self is ScalingType.DYNAMIC
             return "dyn"
 
 
 @dataclass(frozen=True)
-class Float8TensorCastConfig:
+class CastConfig:
     """
     Configuration for casting a single tensor to float8
     """
 
-    scaling_type: TensorScalingType = TensorScalingType.DYNAMIC
+    scaling_type: ScalingType = ScalingType.DYNAMIC
 
 
 @dataclass(frozen=True)
@@ -74,9 +75,9 @@ class Float8LinearConfig:
     #
     # Per-tensor configuration for `input`, `weight`, `grad_output`
     #
-    cast_config_input: Float8TensorCastConfig = Float8TensorCastConfig()
-    cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
-    cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()
+    cast_config_input: CastConfig = CastConfig()
+    cast_config_weight: CastConfig = CastConfig()
+    cast_config_grad_output: CastConfig = CastConfig()
 
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
```
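
For reference, a hypothetical usage of the renamed classes; the keyword
names are taken directly from the diff above, and any `Float8LinearConfig`
fields not shown are assumed to keep their defaults:

```python
from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType

# Pick a scaling type per tensor, mirroring the benchmark changes above.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
```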
26 changes: 12 additions & 14 deletions float8_experimental/float8_linear.py
```diff
@@ -14,7 +14,7 @@
 
 import torch
 
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
@@ -215,9 +215,9 @@ def __init__(self, *args, **kwargs):
         self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
         # Convenience flag to skip code related to delayed scaling
         self.has_any_delayed_scaling = (
-            self.scaling_type_input is TensorScalingType.DELAYED
-            or self.scaling_type_weight is TensorScalingType.DELAYED
-            or self.scaling_type_grad_output is TensorScalingType.DELAYED
+            self.scaling_type_input is ScalingType.DELAYED
+            or self.scaling_type_weight is ScalingType.DELAYED
+            or self.scaling_type_grad_output is ScalingType.DELAYED
         )
 
         self.config = config
@@ -340,7 +340,7 @@ def cast_input_to_float8(
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        if self.scaling_type_input is TensorScalingType.DELAYED:
+        if self.scaling_type_input is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 input,
@@ -361,14 +361,14 @@ def cast_input_to_float8(
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
-            assert self.scaling_type_input is TensorScalingType.DYNAMIC
+            assert self.scaling_type_input is ScalingType.DYNAMIC
             input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
         return input_fp8
 
     def cast_weight_to_float8(
         self, weight: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
-        if self.scaling_type_weight is TensorScalingType.DELAYED:
+        if self.scaling_type_weight is ScalingType.DELAYED:
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
@@ -393,7 +393,7 @@ def cast_weight_to_float8(
                     gemm_input_role=GemmInputRole.WEIGHT,
                 )
         else:
-            assert self.scaling_type_weight is TensorScalingType.DYNAMIC
+            assert self.scaling_type_weight is ScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
@@ -405,7 +405,7 @@ def cast_weight_to_float8(
         return weight_fp8
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
-        if self.scaling_type_grad_output is TensorScalingType.DELAYED:
+        if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             output = NoopFwToFloat8E5M2Bw.apply(
                 output,
@@ -417,7 +417,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
                 self.linear_mm_config,
             )
         else:
-            assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
+            assert self.scaling_type_grad_output is ScalingType.DYNAMIC
             output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
         return output
 
@@ -504,17 +504,15 @@ def from_float(
         # 2. buffers need to be already created for the delayed scaling version
         #    of the weight wrapper to be initialized
         if config.enable_fsdp_float8_all_gather:
-            if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
+            if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
                         new_mod.linear_mm_config,
                     )
                 )
             else:
-                assert (
-                    config.cast_config_weight.scaling_type is TensorScalingType.DELAYED
-                )
+                assert config.cast_config_weight.scaling_type is ScalingType.DELAYED
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDelayedFloat8CastTensor(
                         new_mod.weight,
```
8 changes: 4 additions & 4 deletions float8_experimental/float8_linear_utils.py
```diff
@@ -9,7 +9,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear import Float8Linear
 
 from float8_experimental.float8_utils import (
@@ -27,9 +27,9 @@ def linear_requires_sync(config: Float8LinearConfig):
     """Returns whether the given linear_type requires sync before forward."""
     return any(
         [
-            config.cast_config_input.scaling_type is TensorScalingType.DELAYED,
-            config.cast_config_weight.scaling_type is TensorScalingType.DELAYED,
-            config.cast_config_grad_output.scaling_type is TensorScalingType.DELAYED,
+            config.cast_config_input.scaling_type is ScalingType.DELAYED,
+            config.cast_config_weight.scaling_type is ScalingType.DELAYED,
+            config.cast_config_grad_output.scaling_type is ScalingType.DELAYED,
         ]
     )
 
```
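
A small sketch of how this helper behaves, assuming the defaults shown in
`config.py` above (all-dynamic scaling) and that no other
`Float8LinearConfig` fields are required:

```python
from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
from float8_experimental.float8_linear_utils import linear_requires_sync

# Delayed scaling on any tensor means the amax/scale history must be synced
# before the forward pass; all-dynamic scaling needs no sync step.
delayed_weight = Float8LinearConfig(
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
)
all_dynamic = Float8LinearConfig()

assert linear_requires_sync(delayed_weight)
assert not linear_requires_sync(all_dynamic)
```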
6 changes: 3 additions & 3 deletions float8_experimental/float8_tensor_parallel.py
```diff
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from float8_experimental.config import TensorScalingType
+from float8_experimental.config import ScalingType
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
@@ -28,8 +28,8 @@ def _float8_linear_supports_float8_allgather(m):
     # TODO(future): add support for delayed scaling for activations
     # and gradients
     return (
-        m.scaling_type_input == TensorScalingType.DYNAMIC
-        and m.scaling_type_grad_output == TensorScalingType.DYNAMIC
+        m.scaling_type_input == ScalingType.DYNAMIC
+        and m.scaling_type_grad_output == ScalingType.DYNAMIC
     )
 
 
```
5 changes: 2 additions & 3 deletions float8_experimental/fsdp_utils.py
```diff
@@ -33,13 +33,12 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         optim.step()
         precompute_float8_dynamic_scale_for_fsdp(model)
     """
-    from float8_experimental.config import TensorScalingType
+    from float8_experimental.config import ScalingType
     from float8_experimental.float8_linear import Float8Linear
     from torch.distributed._tensor import DTensor
 
     if any(
-        isinstance(m, Float8Linear)
-        and m.scaling_type_weight is TensorScalingType.DELAYED
+        isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED
         for m in module.modules()
     ):
         raise NotImplementedError("Only supports delayed scaling")
```
(The remaining 8 changed files are not shown.)
