[DISCUSSION] fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor) #326
base: main
Conversation
```python
self.assertTrue(
    isinstance(colwise_param, DTensor)
    and isinstance(
        colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
    )
)
```
edited: without this PR, `torch.chunk` returns a bf16 tensor. FSDP2 happens after TP, thus it only sees `Float8Linear(weight=DTensor(_local_tensor=Tensor))`. With this PR, `torch.chunk` returns `WeightWithDynamicFloat8CastTensor`.
Can you explain where the bf16 came from?
Correcting my wording to be accurate: without this PR, `torch.chunk` returns a plain `Tensor` (can be fp32 or bf16) instead of `WeightWithDynamicFloat8CastTensor`.
```diff
@@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     torch.ops.aten.as_strided.default,
     torch.ops.aten._to_copy.default,
     torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
```
`aten.split` comes from `torch.chunk` when it is called from `distribute_tensor` during TP init.

edited: @awgu curious if you still remember the reason for returning a plain `Tensor` from `torch.chunk` instead of `WeightWithDynamicFloat8CastTensor`. Is it for padding? Any concerns if I change `torch.chunk` to return `WeightWithDynamicFloat8CastTensor`?
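For readers following along, here is a minimal, self-contained sketch of the dispatch behavior under discussion. `MyFloat8CastWeight` and `_PRESERVE_OPS` are made-up stand-ins for `WeightWithDynamicFloat8CastTensor` and its op set, not code from this repo; the point is that `torch.chunk` lowers to `aten.split.Tensor`, so the subclass survives chunking only if that op gets re-wrapped in `__torch_dispatch__`:

```python
import torch
import torch.utils._pytree as pytree

# made-up stand-in for the op set touched by this diff; per the discussion,
# torch.chunk lowers to aten.split.Tensor, so listing it keeps the subclass alive
_PRESERVE_OPS = {
    torch.ops.aten.detach.default,
    torch.ops.aten.split.Tensor,
}

class MyFloat8CastWeight(torch.Tensor):
    """Toy wrapper subclass for illustration (not WeightWithDynamicFloat8CastTensor)."""

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.shape,
            dtype=tensor.dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # run the op on the inner plain tensor
        unwrapped_args = pytree.tree_map_only(cls, lambda t: t._tensor, args)
        unwrapped_kwargs = pytree.tree_map_only(cls, lambda t: t._tensor, kwargs)
        out = func(*unwrapped_args, **unwrapped_kwargs)
        if func in _PRESERVE_OPS:
            # re-wrap the outputs so callers (e.g. distribute_tensor during TP init)
            # still see the subclass after chunking
            return pytree.tree_map_only(torch.Tensor, cls, out)
        # otherwise the subclass is silently dropped and callers get plain Tensors
        return out

w = MyFloat8CastWeight(torch.randn(8, 4))
shards = torch.chunk(w, 2, dim=0)  # hits aten.split.Tensor under the hood
print(type(shards[0]).__name__)    # "MyFloat8CastWeight" only because split is listed
```

`distribute_tensor` performs exactly this chunking when it shards the weight for TP, which is why adding `aten.split.Tensor` to the op list changes what FSDP2 later sees.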
> @awgu curious if you still remember the reason to return bf16 from torch.chunk.

I thought that the dtype and whether it is a `WeightWithDynamicFloat8CastTensor` are orthogonal. Do you mean the latter (whether it is a `WeightWithDynamicFloat8CastTensor` or not)?
I think originally I only added the ops that I saw I needed. Adding `aten.split` and `aten.clone` seems okay to me.
> whether it is a WeightWithDynamicFloat8CastTensor or not

Exactly, `WeightWithDynamicFloat8CastTensor` or not is the key. I edited my previous comments to say that right now `torch.chunk` returns a plain `Tensor`.

> I think originally I only added the ops that I saw I needed

Changing `torch.chunk` affects both TP and FSDP2. Will double-check FSDP2 after the change.
```python
elif isinstance(out, DTensor) and isinstance(
    out._local_tensor, Float8Tensor
):
    out._local_tensor._scale = scale
```
Not sure about this change yet, just want to have something sketchy to discuss first.
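To make the shape of this proposal a bit more concrete, a hedged sketch follows. The helper name `_write_back_scale`, the first `if` branch, and the import paths are assumptions for illustration; only the `elif` mirrors the diff above:

```python
from torch.distributed._tensor import DTensor
from float8_experimental.float8_tensor import Float8Tensor  # import path assumed

def _write_back_scale(out, scale):
    # assumed plain-Float8Tensor branch (not shown in the diff excerpt)
    if isinstance(out, Float8Tensor):
        out._scale = scale
    # the new branch: under TP the Float8Tensor is the DTensor's local shard,
    # so the freshly computed scale has to be written onto the local tensor
    elif isinstance(out, DTensor) and isinstance(out._local_tensor, Float8Tensor):
        out._local_tensor._scale = scale
```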
Drafting this PR for discussion, before having something landable.

We see 2 problems in float8 all-gather FSDP2 + TP:
- the weight is all-gathered in bf16 instead of float8, because FSDP2 does not see `WeightWithDynamicFloat8CastTensor`
- all-reduce for `weight`, but we expect all-reduce only for `input`

The crux is how we dispatch `torch.chunk`, which is called from `distribute_tensor` for TP init:
- without this PR, `torch.chunk` returns `Tensor`. FSDP2 happens after TP, thus it only sees `Float8Linear(weight=DTensor(_local_tensor=Tensor))` (see the check sketched after this list)
- with this PR, `torch.chunk` returns `WeightWithDynamicFloat8CastTensor`
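A quick way to see which case a run is in (a hedged diagnostic sketch; `model` is an assumed TP-parallelized `Float8Linear` model running under `torchrun`, not code from this PR):

```python
from torch.distributed._tensor import DTensor

# print the local-tensor type FSDP2 will see for each TP-sharded parameter:
# a plain Tensor without this PR, WeightWithDynamicFloat8CastTensor with it
for name, param in model.named_parameters():
    if isinstance(param, DTensor):
        print(name, type(param._local_tensor).__name__)
```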
Profiler trace without this PR: AR (all-reduce) for `input` -> AG (all-gather) -> 4 ARs for wq,k,v,o -> 1 AR for `input`. The 4 ARs for wq,k,v,o should not happen if we precompute amax/scales for `model.parameters()` after `opt.step()`.
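A minimal loop sketch of that intended usage (`model`, `opt`, and `dataloader` are assumptions, and the import path may differ from the repo's actual layout):

```python
# import path assumed for this sketch
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    # one batched amax/scale precompute over model.parameters(), so the per-weight
    # ARs (the 4 ARs for wq,k,v,o in the trace) should no longer be needed
    precompute_float8_dynamic_scale_for_fsdp(model)
```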