
Update base for Update on "bring back torch.autograd.Function for float8 matmul"

Summary:

This is a redo of
#316

With upcoming support for scaling granularities other than tensorwise,
we need a good way to control which gemm kernel to call and how to scale
the input tensors in the forward and backward passes. A `torch.autograd.Function`
override is the cleanest way to do that, and as of 2024 it works with
`torch.compile`.
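
A minimal sketch of the shape of such an override, assuming a recent PyTorch with the `torch.float8_e4m3fn` / `torch.float8_e5m2` dtypes. The class name, the plain-matmul stand-in for the float8 gemm kernel, and the omission of scaling factors are illustrative assumptions, not the code from this PR:

```python
import torch

class Float8MatmulSketch(torch.autograd.Function):
    """Illustrative only: shows where per-pass casting/gemm decisions live."""

    @staticmethod
    def forward(ctx, input, weight):
        # forward gemm: activations and weights are cast to float8 (e4m3)
        input_fp8 = input.to(torch.float8_e4m3fn)
        weight_fp8 = weight.to(torch.float8_e4m3fn)
        ctx.save_for_backward(input_fp8, weight_fp8)
        # stand-in for a real float8 gemm kernel: upcast and use a plain matmul
        return input_fp8.to(input.dtype) @ weight_fp8.to(weight.dtype).t()

    @staticmethod
    def backward(ctx, grad_output):
        input_fp8, weight_fp8 = ctx.saved_tensors
        # backward gemms: the incoming gradient is typically cast to e5m2
        # (more dynamic range); each gemm here could pick its own scaling
        grad_fp8 = grad_output.to(torch.float8_e5m2)
        grad_hp = grad_fp8.to(grad_output.dtype)
        grad_input = grad_hp @ weight_fp8.to(grad_output.dtype)
        grad_weight = grad_hp.t() @ input_fp8.to(grad_output.dtype)
        return grad_input, grad_weight

# usage inside a Linear-like module (2D input assumed):
#   out = Float8MatmulSketch.apply(x, w)
```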

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 25, 2024
1 parent dbd5d02 commit 224cfdf
Showing 1 changed file with 5 additions and 5 deletions.
README.md (5 additions, 5 deletions):

```diff
@@ -85,7 +85,7 @@ This is theoretically the most performant recipe as it minimizes memory reads.
 from float8_experimental import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
-    TensorScalingType,
+    ScalingType,
 )
 
 # create model
@@ -95,13 +95,13 @@ m = Model(...)
 # gated with config.enable_amax_init and
 # config.enable_pre_and_post_forward are needed for
 # autocast + compile + FSDP + float8 to work
-from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig
+from float8_experimental import Float8LinearConfig, ScalingType, CastConfig
 config = Float8LinearConfig(
     enable_amax_init = False, # only needed for autocast + compile + FSDP + float8 delayed
     enable_pre_and_post_forward, False # only needed for autocast + compile + FSDP + float8 delayed
-    cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_grad_output=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
+    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
 )
 
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
```
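
For reference, a hedged sketch of the delayed-scaling setup as it reads with the renamed types. The `convert_to_float8_training(m, config=config)` call and the training-loop sync step are assumed from the surrounding README rather than shown in this diff, and keyword form `enable_pre_and_post_forward=False` is used where the README snippet has what looks like a typo (`, False`):

```python
import torch
from float8_experimental import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)

# delayed scaling for the input, weight, and grad_output casts
config = Float8LinearConfig(
    enable_amax_init=False,             # only needed for autocast + compile + FSDP + float8 delayed
    enable_pre_and_post_forward=False,  # only needed for autocast + compile + FSDP + float8 delayed
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)

m = torch.nn.Sequential(torch.nn.Linear(64, 64))  # stand-in for Model(...)
convert_to_float8_training(m, config=config)      # swaps nn.Linear -> Float8Linear

# in the training loop, delayed scaling needs its amax/scale history synced,
# e.g. after backward and before optimizer.step():
# sync_float8_amax_and_scale_history(m)
```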
