[DISCUSSION] fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor) #326
First changed file (float8 FSDP utilities):

@@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     torch.ops.aten.as_strided.default,
     torch.ops.aten._to_copy.default,
     torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.clone.default,
 }
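The two new entries extend the set of ATen ops for which the WeightWithDynamicFloat8CastTensor subclass is preserved through `__torch_dispatch__` (per the comment thread at the end, `torch.chunk`, called by `distribute_tensor` during TP init, lowers to `aten.split.Tensor`). Below is a minimal, self-contained sketch of that allowlist pattern using a toy `WrapperWeight` subclass; it illustrates the mechanism only and is not the actual WeightWithDynamicFloat8CastTensor implementation, and the op set shown is a stand-in.

```python
import torch
from torch.utils._pytree import tree_map_only

# Toy stand-in for the real preserve-subclass op set.
_OPS_TO_PRESERVE_SUBCLASS = {
    torch.ops.aten.detach.default,
    torch.ops.aten.clone.default,
    torch.ops.aten.split.Tensor,  # torch.chunk lowers to aten.split.Tensor
}


class WrapperWeight(torch.Tensor):
    """Toy wrapper subclass; not the real WeightWithDynamicFloat8CastTensor."""

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.shape,
            dtype=tensor.dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        # Unwrap to the inner plain tensor, run the op, then re-wrap the
        # outputs only for allowlisted ops so the subclass is preserved.
        args = tree_map_only(WrapperWeight, lambda t: t._tensor, args)
        kwargs = tree_map_only(WrapperWeight, lambda t: t._tensor, kwargs)
        out = func(*args, **kwargs)
        if func in _OPS_TO_PRESERVE_SUBCLASS:
            return tree_map_only(torch.Tensor, WrapperWeight, out)
        return out


w = WrapperWeight(torch.randn(8, 4))
# torch.chunk dispatches aten.split.Tensor; each chunk stays a WrapperWeight
# only because split is in the allowlist.
chunks = torch.chunk(w, 2, dim=0)
print([type(c).__name__ for c in chunks])  # ['WrapperWeight', 'WrapperWeight']
cloned = w.clone()                          # aten.clone.default, also preserved
print(type(cloned).__name__)                # 'WrapperWeight'
```

Without `aten.split.Tensor` in the set, `torch.chunk` hands back plain Tensor chunks, which is the failure mode the new test below asserts against.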
@@ -188,12 +190,22 @@ def fsdp_post_all_gather(
         *,
         out: Optional[torch.Tensor] = None,
     ):
+        from torch.distributed._tensor import DTensor
+
         (data,) = all_gather_outputs
         (scale,) = metadata
         if out is not None:
-            assert isinstance(out, Float8Tensor), f"{type(out)}"
-            out._scale = scale
-            return
+            if isinstance(out, Float8Tensor):
+                out._scale = scale
+            elif isinstance(out, DTensor) and isinstance(
+                out._local_tensor, Float8Tensor
+            ):
+                out._local_tensor._scale = scale
+            else:
+                raise RuntimeError(
+                    f"out must be a Float8Tensor or DTensor with Float8Tensor local tensor, but got {type(out)}"
+                )
+            return out
         return Float8Tensor(
             data,
             scale,

Inline comment (on the DTensor branch): not sure about this change yet; just want to have something sketchy to discuss first.
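For context on the new branch: under FSDP2 + TP the preallocated all-gather output `out` is a DTensor whose local shard is the Float8Tensor, so the scale has to be written onto `out._local_tensor`. The single-process sketch below only shows that wrapper structure; the single-rank gloo/CPU mesh and the plain `randn` stand-in for the Float8Tensor are assumptions to keep it runnable outside the multi-GPU test.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, init_device_mesh

# Single-rank process group so a DTensor can be constructed in one process.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
if not dist.is_initialized():
    dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,), mesh_dim_names=("tp",))

# Stand-in for the Float8Tensor shard that fsdp_post_all_gather receives.
local = torch.randn(8, 4)
out = DTensor.from_local(local, mesh, [Shard(0)])

# The DTensor itself carries no _scale; the float8 payload lives on the local
# tensor, which is why the new branch updates out._local_tensor._scale.
print(type(out).__name__, type(out._local_tensor).__name__)

dist.destroy_process_group()
```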
Second changed file (FSDP2 float8 tests):
@@ -17,7 +17,13 @@
     set_enable_fsdp_fp8_all_gather,
 )
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import (
+    distribute_tensor,
+    DTensor,
+    init_device_mesh,
+    Shard,
+)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
@@ -516,5 +522,50 @@ def test_delayed_scaling_inplace_update(self):
         self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item())
 
 
+class Test2DFloat8MultiProcess(FSDPTest, TestFloat8Common):
+    @property
+    def world_size(self) -> int:
+        return min(torch.cuda.device_count(), 4)
+
+    def init_global_mesh(self) -> DeviceMesh:
+        dp_size = 2 if self.world_size > 2 else 1
+        return init_device_mesh(
+            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+        )
+
+    @skip_if_lt_x_gpu(4)
+    def test_fsdp_tp(
+        self,
+    ):
+        enable_fsdp_fp8_all_gather = True
+        scaling_type_w = TensorScalingType.DYNAMIC
+        global_mesh = self.init_global_mesh()
+        _, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+        module = self.init_transformer(weight_tying=False).cuda()
+        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+            swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
+
+        # "attention.wq": Float8ColwiseParallel
+        colwise_param = distribute_tensor(
+            module.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
+        )
+        self.assertTrue(
+            isinstance(colwise_param, DTensor)
+            and isinstance(
+                colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
+            )
+        )
+        # "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
+        rowwise_param = distribute_tensor(
+            module.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
+        )
+        self.assertTrue(
+            isinstance(rowwise_param, DTensor)
+            and isinstance(
+                rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
+            )
+        )
+
+
 if __name__ == "__main__":
     run_tests()

Inline comment thread on the colwise_param assertion:

- edited: without this PR, …
- Can you explain where the bf16 came from?
- to correct my wording: without this PR, …
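The test above exercises `distribute_tensor` directly rather than a full parallelization pass. For orientation, here is a rough sketch of the 2D wiring it is building toward; Float8ColwiseParallel and Float8RowwiseParallel are the TP styles named in the inline comments, while the import path, the per-layer plan keys, and the helper name `apply_fsdp_tp` are assumptions for illustration, not code from this PR.

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module

# Assumed import path for the float8-aware TP styles referenced in the test comments.
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


def apply_fsdp_tp(module: nn.Module, global_mesh: DeviceMesh) -> nn.Module:
    """Hypothetical helper: TP on the "tp" mesh dim, then FSDP2 on the "dp" dim."""
    dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
    for layer in module.layers:
        # distribute_tensor() inside these styles is what hits aten.split.Tensor
        # on the WeightWithDynamicFloat8CastTensor weights (see the op-set diff).
        parallelize_module(
            layer,
            tp_mesh,
            {
                "attention.wq": Float8ColwiseParallel(),
                "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
            },
        )
        # FSDP2 then sees DTensor(WeightWithDynamicFloat8CastTensor) parameters;
        # with fp8 all-gather enabled, fsdp_post_all_gather must handle the
        # DTensor(Float8Tensor) output, which is what the first file's diff adds.
        fully_shard(layer, mesh=dp_mesh)
    fully_shard(module, mesh=dp_mesh)
    return module
```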
Comment thread on the added ops (aten.split.Tensor / aten.clone.default):

- aten.split comes from torch.chunk, which is called from distribute_tensor during TP init. edited: @awgu, curious whether you still remember the reason for torch.chunk returning Tensor instead of WeightWithDynamicFloat8CastTensor. Is it for padding? Any concerns if I prefer torch.chunk to return WeightWithDynamicFloat8CastTensor?
- I thought that dtype and whether it is a WeightWithDynamicFloat8CastTensor are orthogonal. Do you mean the latter (whether it is a WeightWithDynamicFloat8CastTensor or not)? I think originally I only added the ops that I saw I needed. Adding aten.split and aten.clone seems okay to me.
- Exactly, WeightWithDynamicFloat8CastTensor or not is the key; I edited my previous comments to say that right now torch.chunk returns Tensor. Changing torch.chunk affects both TP and FSDP2, so I will double-check FSDP2 after the change.
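A quick way to confirm the claim above, that torch.chunk shows up as aten.split.Tensor at the `__torch_dispatch__` level, is to log ops with a TorchDispatchMode; this is a standalone sketch, not part of the PR.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class LogOps(TorchDispatchMode):
    """Print every ATen op that reaches the dispatch level."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))


x = torch.randn(8, 4)
with LogOps():
    torch.chunk(x, 2, dim=0)  # prints aten.split.Tensor (chunk decomposes to split)
```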