This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[DISCUSSION] fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor) #326

Draft. Wants to merge 10 commits into base: main.
15 changes: 7 additions & 8 deletions float8_experimental/float8_linear.py
@@ -357,12 +357,9 @@ def cast_w_to_float8(
)
else:
assert self.scaling_type_w is TensorScalingType.DYNAMIC
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3_dynamic(
self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
)
w_fp8 = cast_to_float8_e4m3_dynamic(
self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
)
return w_fp8

def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -407,8 +404,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.has_any_delayed_scaling:
self.float8_pre_forward(input)

x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
with torch.profiler.record_function("cast_x_to_float8"):
x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
with torch.profiler.record_function("cast_w_to_float8"):
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

y = torch.matmul(x_fp8, w_fp8.t())

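Note (illustrative sketch, not from this PR's diff): the hunk above wraps the activation and weight casts in torch.profiler.record_function so they show up as named regions in a profiler trace. A minimal sketch of how such a trace might be collected and inspected; the nn.Linear stand-in and the shapes are assumptions for illustration.

# Sketch: collect a profile; with a swapped Float8Linear, the labels added in
# this PR ("cast_x_to_float8", "cast_w_to_float8") would appear as named
# record_function regions in the key averages or an exported trace.
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Linear(512, 512, device="cuda")  # stand-in for a Float8Linear
x = torch.randn(32, 512, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(3):
        model(x)

print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))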
41 changes: 23 additions & 18 deletions float8_experimental/float8_tensor_parallel.py
@@ -44,12 +44,13 @@ def _prepare_input_fn(
input_tensor = DTensor.from_local(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)

with torch.profiler.record_function("colwise_cast_to_float8_e4m3_dynamic"):
input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)

# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
@@ -67,7 +68,8 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
) # DTensor(torch.Tensor)

# fwd noop bwd cast to DTensor(Float8Tensor)
outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)
with torch.profiler.record_function("colwise_cast_to_float8_e5m2_dynamic_bw"):
outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)

# back to local tensor
return outputs.to_local() if use_local_output else outputs
@@ -98,11 +100,12 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)
with torch.profiler.record_function("rowwise_cast_to_float8_e4m3_dynamic"):
input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)

if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(
@@ -119,7 +122,8 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
outputs = outputs.redistribute(placements=output_layouts, async_op=True)

# fwd noop bwd cast to DTensor(Float8Tensor)
outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)
with torch.profiler.record_function("rowwise_cast_to_float8_e5m2_dynamic_bw"):
outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)

# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
@@ -196,11 +200,12 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
input, mesh, (input_layout,), run_check=False
)

dt_inp = cast_to_float8_e4m3_dynamic(
dt_inp,
self.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)
with torch.profiler.record_function("prepareinput_cast_to_float8_e4m3_dynamic"):
dt_inp = cast_to_float8_e4m3_dynamic(
dt_inp,
self.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)
if desired_layout is not None and input_layout != desired_layout:
dt_inp = dt_inp.redistribute(placements=(desired_layout,))

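Note (illustrative sketch, not from this PR's diff): a hedged sketch of how the float8-aware TP styles touched above are typically combined in a parallelize plan, mirroring the test added later in this PR. The submodule names and the tp_mesh are assumptions for illustration.

# Float8ColwiseParallel / Float8RowwiseParallel cast the module input to float8
# (e4m3) in forward and cast the output's gradient to e5m2 in backward;
# PrepareFloat8ModuleInput casts the module input once so that sibling linears
# (e.g. wq/wk/wv) reuse the same Float8Tensor input.
from torch.distributed._tensor import Replicate, Shard

from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)

layer_plan = {
    "attention": PrepareFloat8ModuleInput(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "attention.wq": Float8ColwiseParallel(),
    "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
}
# then: parallelize_module(transformer_block, tp_mesh, layer_plan)  # mesh/block assumed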
22 changes: 19 additions & 3 deletions float8_experimental/fsdp_utils.py
Original file line number Diff line number Diff line change
@@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
torch.ops.aten.as_strided.default,
torch.ops.aten._to_copy.default,
torch.ops.aten._pin_memory.default,
torch.ops.aten.split.Tensor,
@weifengpy (Contributor, Author) commented on Jul 24, 2024:

aten.split comes from torch.chunk, which is called from distribute_tensor during TP init.

edited: @awgu, curious if you still remember the reason for returning a plain Tensor from torch.chunk instead of WeightWithDynamicFloat8CastTensor. Is it for padding? Any concerns if I change torch.chunk to return WeightWithDynamicFloat8CastTensor?

@awgu (Contributor) replied:

> @awgu curious if you still remember the reason to return bf16 from torch.chunk.

I thought that the dtype and whether it is a WeightWithDynamicFloat8CastTensor are orthogonal. Do you mean the latter (whether it is a WeightWithDynamicFloat8CastTensor or not)?

I think originally I only added the ops that I saw I needed. Adding aten.split and aten.clone seems okay to me.

@weifengpy (Contributor, Author) replied:

> whether it is a WeightWithDynamicFloat8CastTensor or not

Exactly, whether it is a WeightWithDynamicFloat8CastTensor or not is the key. I edited my previous comment to say that right now torch.chunk returns a plain Tensor.

> I think originally I only added the ops that I saw I needed

Changing torch.chunk affects both TP and FSDP2; I will double-check FSDP2 after the change.

torch.ops.aten.clone.default,
}


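Note (illustrative sketch, not from this PR's diff): a toy, self-contained sketch of the mechanism discussed in the thread above. It only illustrates why aten.split.Tensor needs to be in the preserve set for torch.chunk (which distribute_tensor uses when sharding) to keep returning the wrapper subclass instead of a plain Tensor; the class below is a made-up stand-in, not this repo's code.

import torch
import torch.utils._pytree as pytree

_OPS_TO_PRESERVE = {torch.ops.aten.split.Tensor}  # analogous to _ops_to_preserve_subclass


class ToyWrapperTensor(torch.Tensor):
    """Toy stand-in for WeightWithDynamicFloat8CastTensor."""

    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner: torch.Tensor):
        self._inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # unwrap subclass args and run the real op on plain tensors
        args, kwargs = pytree.tree_map_only(
            cls, lambda t: t._inner, (args, kwargs or {})
        )
        out = func(*args, **kwargs)
        if func not in _OPS_TO_PRESERVE:
            return out  # outputs decay to plain torch.Tensor
        # rewrap outputs so the subclass survives the op
        return pytree.tree_map_only(torch.Tensor, lambda t: cls(t), out)


w = ToyWrapperTensor(torch.randn(8, 4))
shards = torch.chunk(w, 2, dim=0)  # chunk dispatches to aten.split.Tensor
print(type(shards[0]).__name__)    # ToyWrapperTensor, because split is preserved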
@@ -138,6 +140,10 @@ def unwrap(t):
WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
)
out = func(*args, **kwargs)
if func is torch.ops.aten.split.Tensor:
# if func is torch.ops.aten.clone.default:
if torch.distributed.get_rank() == 0:
print(f"dispatched {func=}", flush=True)
if func not in _ops_to_preserve_subclass:
return out
return pytree.tree_map_only(
@@ -188,12 +194,22 @@ def fsdp_post_all_gather(
*,
out: Optional[torch.Tensor] = None,
):
from torch.distributed._tensor import DTensor

(data,) = all_gather_outputs
(scale,) = metadata
if out is not None:
assert isinstance(out, Float8Tensor), f"{type(out)}"
out._scale = scale
return
if isinstance(out, Float8Tensor):
out._scale = scale
elif isinstance(out, DTensor) and isinstance(
out._local_tensor, Float8Tensor
):
out._local_tensor._scale = scale
@weifengpy (Contributor, Author) commented:

Not sure about this change yet; I just want to have something sketchy to discuss first.

else:
raise RuntimeError(
f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}"
)
return out
return Float8Tensor(
data,
scale,
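Note (illustrative sketch, not from this PR's diff): the new DTensor branch above targets the FSDP2 + TP composition, where the unsharded parameter that FSDP2 reuses (`out` in fsdp_post_all_gather) is a DTensor wrapping a Float8Tensor rather than a bare Float8Tensor. A hedged sketch of the composition that produces such parameters; the mesh sizes, the "linear" submodule name, and the plain ColwiseParallel plan are assumptions for illustration (this PR's test uses the Float8* styles instead).

# Apply TP first, then FSDP2. After parallelize_module, parameters are DTensors
# sharded over the "tp" mesh; fully_shard then shards them over the "dp" mesh,
# so FSDP2's all-gather extension runs on the DTensor's local tensor (with this
# PR, a WeightWithDynamicFloat8CastTensor), and the reused unsharded parameter
# is a DTensor whose local tensor is a Float8Tensor.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


def apply_tp_then_fsdp(model: nn.Module) -> nn.Module:
    mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    parallelize_module(model, mesh_2d["tp"], {"linear": ColwiseParallel()})
    fully_shard(model, mesh=mesh_2d["dp"])
    return model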
146 changes: 145 additions & 1 deletion test/test_fsdp2/test_fsdp2.py
@@ -16,8 +22,
check_parity_no_mp,
set_enable_fsdp_fp8_all_gather,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import DTensor
from torch.distributed._tensor import (
distribute_tensor,
DTensor,
init_device_mesh,
Shard,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
@@ -516,5 +530,135 @@ def test_delayed_scaling_inplace_update(self):
self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item())


class Test2DFloat8MultiProcess(FSDPTest, TestFloat8Common):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 4)

def init_global_mesh(self) -> DeviceMesh:
dp_size = 2 if self.world_size > 2 else 1
return init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
)

    def parallelize(
        self, module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool
    ) -> nn.Module:
        assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
        # Parallelize the root submodules.
        module_tp = parallelize_module(
            module,
            device_mesh,
            {
                "tok_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Shard(1)
                ),
                "pos_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Shard(0)
                ),
                "norm": SequenceParallel(),
            },
        )

        # Parallelize the attention and feed forward submodules with the float8 plans.
        for transformer_block in module_tp.layers:
            layer_plan = {
                "attention": PrepareFloat8ModuleInput(
                    input_layouts=(Shard(1), None),
                    desired_input_layouts=(Replicate(), None),
                ),
                # shard the RMSNorms
                "attention_norm": SequenceParallel(),
                "ffn_norm": SequenceParallel(),
                "attention.wq": Float8ColwiseParallel(),
                "attention.wk": Float8ColwiseParallel(),
                "attention.wv": Float8ColwiseParallel(),
                "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
                "feed_forward": PrepareFloat8ModuleInput(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
                "feed_forward.w1": Float8ColwiseParallel(),
                "feed_forward.w2": Float8RowwiseParallel(output_layouts=Shard(1)),
                "feed_forward.w3": Float8ColwiseParallel(),
            }

            # Adjust the attention module to use the local number of heads.
            attn_layer = transformer_block.attention
            attn_layer.n_heads = attn_layer.n_heads // device_mesh.size()
            attn_layer.n_kv_heads = attn_layer.n_kv_heads // device_mesh.size()

            parallelize_module(
                module=transformer_block,
                device_mesh=device_mesh,
                parallelize_plan=layer_plan,
            )

        # Parallelize the output submodule. If weight tying is enabled, we need to
        # make sure output.weight is sharded consistently with tok_embeddings.weight,
        # at the cost of the all_reduce operation using RowwiseParallel.
        output_parallelize_plan = (
            ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
            )
            if use_seq_parallel
            else ColwiseParallel(output_layouts=Replicate())
        )
        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)

        # Manually set output.weight so that parameters and gradients are shared.
        if module_tp.model_args.weight_tying:
            module_tp.output.weight = module_tp.tok_embeddings.weight

        return module_tp

    @skip_if_lt_x_gpu(4)
    def test_fsdp_tp(self):
        enable_fsdp_fp8_all_gather = True
        scaling_type_w = TensorScalingType.DYNAMIC
        global_mesh = self.init_global_mesh()
        _, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        model = self.init_transformer(weight_tying=False).cuda()
        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
            swap_linear_with_float8_linear(model, scaling_type_w=scaling_type_w)

# "attention.wq": Float8ColwiseParallel
        colwise_param = distribute_tensor(
            model.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
        )
        self.assertTrue(
            isinstance(colwise_param, DTensor)
            and isinstance(
                colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
            )
        )

@weifengpy (Contributor, Author) commented on Jul 24, 2024:

edited: Without this PR, torch.chunk returns a bf16 tensor; FSDP2 happens after TP and thus only sees Float8Linear(weight=DTensor(_local_tensor=Tensor)). With this PR, torch.chunk returns WeightWithDynamicFloat8CastTensor.

A contributor replied:

Can you explain where the bf16 came from?

@weifengpy (Contributor, Author) replied:

To correct my wording: without this PR, torch.chunk returns a plain Tensor (can be fp32 or bf16) instead of WeightWithDynamicFloat8CastTensor.
# "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
rowwise_param = distribute_tensor(
model.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
)
self.assertTrue(
isinstance(rowwise_param, DTensor)
and isinstance(
rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
)
)


if __name__ == "__main__":
run_tests()