Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[DISCUSSION] fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor) #326

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
18 changes: 15 additions & 3 deletions float8_experimental/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
torch.ops.aten.as_strided.default,
torch.ops.aten._to_copy.default,
torch.ops.aten._pin_memory.default,
torch.ops.aten.split.Tensor,
Copy link
Contributor Author

@weifengpy weifengpy Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aten.split comes from torch.chunk, which distribute_tensor calls during TP init

edited: @awgu curious if you still remember the reason torch.chunk returns a plain Tensor instead of WeightWithDynamicFloat8CastTensor. Is it for padding? Any concerns if I change torch.chunk to return WeightWithDynamicFloat8CastTensor?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@awgu curious if you still remember the reason to return bf16 from torch.chunk.

I thought that dtype and whether is WeightWithDynamicFloat8CastTensor are orthogonal. Do you mean the latter (whether is WeightWithDynamicFloat8CastTensor or not)?

I think originally I only added the ops that I saw I needed. Adding aten.split and aten.clone seems okay to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whether is WeightWithDynamicFloat8CastTensor or not

exactly, WeightWithDynamicFloat8CastTensor or not is the key. I edited my previous comments to say right now torch.chunk returns Tensor

I think originally I only added the ops that I saw I needed

changing torch.chunk affects both TP and FSDP2. will double check FSDP2 after the change

torch.ops.aten.clone.default,
}


Expand Down Expand Up @@ -188,12 +190,22 @@ def fsdp_post_all_gather(
*,
out: Optional[torch.Tensor] = None,
):
from torch.distributed._tensor import DTensor

(data,) = all_gather_outputs
(scale,) = metadata
if out is not None:
assert isinstance(out, Float8Tensor), f"{type(out)}"
out._scale = scale
return
if isinstance(out, Float8Tensor):
out._scale = scale
elif isinstance(out, DTensor) and isinstance(
out._local_tensor, Float8Tensor
):
out._local_tensor._scale = scale
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure about this change yet. just want to have something sketchy to discuss first

else:
raise RuntimeError(
f"out must be a Float8Tensor or DTensor with Float8Tensor local tensor, but got {type(out)}"
)
return out
return Float8Tensor(
data,
scale,
Expand Down
53 changes: 52 additions & 1 deletion test/test_fsdp2/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
set_enable_fsdp_fp8_all_gather,
)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import DTensor
from torch.distributed._tensor import (
distribute_tensor,
DTensor,
init_device_mesh,
Shard,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
Expand Down Expand Up @@ -516,5 +522,50 @@ def test_delayed_scaling_inplace_update(self):
self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item())


class Test2DFloat8MultiProcess(FSDPTest, TestFloat8Common):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 4)

def init_global_mesh(self) -> DeviceMesh:
dp_size = 2 if self.world_size > 2 else 1
return init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
)

@skip_if_lt_x_gpu(4)
def test_fsdp_tp(
self,
):
enable_fsdp_fp8_all_gather = True
scaling_type_w = TensorScalingType.DYNAMIC
global_mesh = self.init_global_mesh()
_, tp_mesh = global_mesh["dp"], global_mesh["tp"]
module = self.init_transformer(weight_tying=False).cuda()
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)

# "attention.wq": Float8ColwiseParallel
colwise_param = distribute_tensor(
module.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
)
self.assertTrue(
isinstance(colwise_param, DTensor)
and isinstance(
colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
Copy link
Contributor Author

@weifengpy weifengpy Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edited: without this PR, torch.chunk returns a bf16 tensor. FSDP2 is applied after TP, so it only sees Float8Linear(weight=DTensor(_local_tensor=Tensor))
with this PR, torch.chunk returns WeightWithDynamicFloat8CastTensor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain where the bf16 came from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to correct my wording: without this PR, torch.chunk returns a plain Tensor (can be fp32 or bf16) instead of WeightWithDynamicFloat8CastTensor

)
)
# "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
rowwise_param = distribute_tensor(
module.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
)
self.assertTrue(
isinstance(rowwise_param, DTensor)
and isinstance(
rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
)
)


if __name__ == "__main__":
run_tests()
Loading