Commit

fix comment
li126com committed Jul 12, 2024
1 parent 2fe2891 commit 537d470
Showing 3 changed files with 77 additions and 36 deletions.
68 changes: 50 additions & 18 deletions internlm/core/parallel/comm/tensor.py
@@ -66,7 +66,9 @@ def input_hook(

@abstractmethod
def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False
self,
grad_output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
communication for grad_output when backward.
@@ -82,7 +84,9 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T

@abstractmethod
def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
self,
output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
communication for output when forward.
@@ -95,13 +99,14 @@ class TensorParallelCommunicator(TPCommunicator):
tensor parallel communicator for linear
"""

def __init__(self, process_group: dist.ProcessGroup, role: LinearRole) -> None:
def __init__(self, process_group: dist.ProcessGroup, role: LinearRole, last_block_layer=False) -> None:
assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}"

self._process_group = process_group
self._role = role

self._save_total_input = False
self.last_block_layer = last_block_layer

def save_total_input(self) -> bool:
return self._save_total_input
@@ -120,8 +125,7 @@ def input_hook(
def grad_output_hook(
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
tensor parallel should do nothing for grad_output.
@@ -138,12 +142,18 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op)

def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
self,
output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all reduce output only for row parallel linear when forward.
"""
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if (
(self.last_block_layer and gpc.recompute_forward_no_comm)
or dist.get_world_size(self._process_group) <= 1
or self._role == LinearRole.COLUMN
):
return output, DUMMY_HANDLE_CONST

return all_reduce_raw(output, process_group=self._process_group, async_op=async_op)
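As an aside, the gating above can be summarized by a minimal standalone sketch (the function and argument names below are illustrative, not the repo's API): the row-parallel all-reduce is skipped when the layer is tagged as the last block layer while the recompute-forward-no-comm flag is set, when the tensor-parallel group has a single rank, or when the linear is column-parallel.

def should_all_reduce_output(
    last_block_layer: bool,
    recompute_forward_no_comm: bool,
    tp_world_size: int,
    role: str,
) -> bool:
    # Skip the collective for the tagged last block layer while the
    # recompute-forward-no-comm optimization is active.
    if last_block_layer and recompute_forward_no_comm:
        return False
    # A single-rank group has nothing to reduce.
    if tp_world_size <= 1:
        return False
    # Only row-parallel linears all-reduce their partial outputs.
    if role == "column":
        return False
    return True

assert should_all_reduce_output(True, True, 8, "row") is False
assert should_all_reduce_output(False, True, 8, "row") is True
assert should_all_reduce_output(True, False, 8, "row") is True
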
@@ -155,14 +165,20 @@ class SequenceParallelCommunicator(TPCommunicator):
"""

def __init__(
self, process_group: dist.ProcessGroup, role: LinearRole, save_total_input_as_activation: bool = False
self,
process_group: dist.ProcessGroup,
role: LinearRole,
save_total_input_as_activation: bool = False,
last_block_layer=False,
) -> None:
assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}"

self._process_group = process_group
self._role = role

self._save_total_input = save_total_input_as_activation
self.last_block_layer = last_block_layer
self.no_communication = False

def save_total_input(self) -> bool:
return self._save_total_input
@@ -189,12 +205,19 @@ def input_hook(
return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False
self,
grad_output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather grad_output only for row parallel linear when backward.
"""
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if (
(self.last_block_layer and self.no_communication)
or dist.get_world_size(self._process_group) <= 1
or self._role == LinearRole.COLUMN
):
self.no_communication = False
return grad_output, DUMMY_HANDLE_CONST

return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)
@@ -211,12 +234,19 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
)

def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
self,
output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
reduce scatter output only for row parallel linear when forward.
"""
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
self.no_communication = gpc.recompute_forward_no_comm
if (
(self.last_block_layer and self.no_communication)
or dist.get_world_size(self._process_group) <= 1
or self._role == LinearRole.COLUMN
):
return output, DUMMY_HANDLE_CONST

return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM)
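The sequence-parallel variant adds a small forward/backward handshake. A simplified sketch of that bookkeeping, assuming only the last-block-layer condition matters (world size and role are ignored): output_hook latches whether the reduce-scatter was skipped, and grad_output_hook consumes and resets the latch so only the matching backward pass skips its all-gather.

class NoCommLatch:
    """Illustrative stand-in for the no_communication bookkeeping; not the repo's class."""

    def __init__(self) -> None:
        self.no_communication = False

    def on_forward(self, recompute_forward_no_comm: bool) -> bool:
        # Latch the flag; return whether the output gets reduce-scattered.
        self.no_communication = recompute_forward_no_comm
        return not self.no_communication

    def on_backward(self) -> bool:
        # Consume the latch; return whether grad_output gets all-gathered.
        skip = self.no_communication
        self.no_communication = False
        return not skip

latch = NoCommLatch()
assert latch.on_forward(True) is False   # forward skipped the reduce-scatter
assert latch.on_backward() is False      # matching backward skips the all-gather
assert latch.on_backward() is True       # later backwards communicate as usual
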
@@ -236,8 +266,7 @@ def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True)
def grad_output_hook(
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
split grad_output if retain_out_sharded is False.
@@ -248,7 +277,9 @@ def grad_output_hook(
return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST

def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613
self,
output: torch.Tensor,
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather output for head layer if retain_out_sharded is False.
@@ -280,8 +311,7 @@ def __init__(
def grad_output_hook(
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
split grad_output if retain_out_sharded is False.
@@ -293,7 +323,9 @@ def grad_output_hook(

# rewrite output communication hook
def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613
self,
output: torch.Tensor,
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather output for head layer if retain_out_sharded is False.
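For context, the retain_out_sharded switch in the head communicators behaves roughly as follows (a toy sketch using torch.cat and torch.chunk in place of the repo's collectives): when the head output should not stay sharded, forward gathers the per-rank shards along the last dimension and backward splits grad_output back to the local shard.

import torch

def gather_head_output(local_shards, retain_out_sharded: bool):
    # Forward: concatenate per-rank shards along the last dim unless the
    # output is meant to stay sharded.
    if retain_out_sharded:
        return local_shards
    return torch.cat(local_shards, dim=-1)

def split_head_grad_output(grad_output, rank: int, world_size: int, retain_out_sharded: bool):
    # Backward: hand each rank its slice of grad_output unless it is already sharded.
    if retain_out_sharded:
        return grad_output
    return torch.chunk(grad_output, world_size, dim=-1)[rank]

shards = [torch.ones(2, 4), torch.zeros(2, 4)]  # one shard per tensor-parallel rank
full = gather_head_output(shards, retain_out_sharded=False)
assert full.shape == (2, 8)
assert split_head_grad_output(full, rank=1, world_size=2, retain_out_sharded=False).shape == (2, 4)
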
24 changes: 6 additions & 18 deletions internlm/model/modules/linear.py
@@ -45,12 +45,10 @@ def forward(
bias: Optional[torch.Tensor],
communicator: TPCommunicator,
return_residual=False,
no_communication=False,
):
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.communicator = communicator
ctx.no_communication = no_communication

if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
@@ -79,7 +77,7 @@ def forward(

# parallel strategy-specific communication callback 2.
# see more details in the communicator for different parallel strategies.
output, _ = communicator.output_hook(output, async_op=False, no_communication=no_communication)
output, _ = communicator.output_hook(output, async_op=False)

saved_x = None if ctx.compute_weight_gradient is False else total_x if communicator.save_total_input() else x
ctx.save_for_backward(saved_x, weight)
@@ -93,9 +91,7 @@ def backward(ctx, grad_output, *args):

# parallel strategy-specific communication callback 3.
# see more details in the communicator for different parallel strategies.
grad_output, _ = communicator.grad_output_hook(
grad_output, no_communication=ctx.no_communication, async_op=False
)
grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False)
grad_output = grad_output.contiguous()

if ctx.return_residual:
@@ -268,7 +264,6 @@ def fused_dense_func(
module: Optional[nn.Module] = None,
bias: Optional[torch.Tensor] = None,
return_residual: bool = False,
no_communication=False,
):
if communicator.communication_mode() == "wp":
return WPFusedDenseFunc.apply(
@@ -286,7 +281,6 @@ def fused_dense_func(
bias,
communicator,
return_residual,
no_communication,
)
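Roughly, fused_dense_func now leaves every communication decision to the communicator's hooks; a toy, single-process sketch of the forward path (stand-in names, no autograd or process groups):

import torch

class DummyCommunicator:
    # Stand-in communicator whose hooks do nothing; a real one may all-reduce,
    # reduce-scatter, or skip communication for the tagged last block layer.
    def communication_mode(self) -> str:
        return "tp"

    def output_hook(self, output, async_op=False):
        return output, None

def linear_forward_sketch(x, weight, communicator, bias=None):
    # The linear math is unchanged; only the communicator decides whether the
    # output is reduced across ranks.
    output = x @ weight.t()
    if bias is not None:
        output = output + bias
    output, _ = communicator.output_hook(output, async_op=False)
    return output

x, w = torch.randn(4, 8), torch.randn(16, 8)
assert linear_forward_sketch(x, w, DummyCommunicator()).shape == (4, 16)
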


@@ -349,19 +343,15 @@ def __init__(
else:
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

self.last_block_layer = False

def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622
_class_name = self.__class__.__name__
assert self._communicator is not None, f"{_class_name} should register with a communicator first."
no_communication = bool(gpc.recompute_forward_no_comm and self.last_block_layer)
return fused_dense_func(
input,
self.weight,
communicator=self._communicator,
module=self,
bias=self.bias,
no_communication=no_communication,
)


@@ -417,7 +407,6 @@ def __init__(
multiple_of: int = 1,
device: torch.device = None,
dtype: torch.dtype = None,
layer_name: str = "default",
) -> None:
if in_features % multiple_of:
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
@@ -434,9 +423,6 @@
split_mode="row",
)

if layer_name == "w2":
self.last_block_layer = True


class ScaleColumnParallelLinear(ParallelLinearWithCommExt):
"""
@@ -602,15 +588,17 @@ def new_linear(
dtype,
)
elif split_mode == "row":
return RowParallelLinear(
linear = RowParallelLinear(
in_features,
out_features,
bias,
multiple_of,
device,
dtype,
layer_name=name,
)
if name == "w2":
setattr(linear, "last_block_layer", True)
return linear
else:
err_msg = (
f"Parallel strategies for linear is unsupported, which is named as {name}.\n"
21 changes: 21 additions & 0 deletions internlm/train/pipeline.py
@@ -277,6 +277,15 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
)
_head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded)
_embedding_communicator = EmbbedingTensorParallelCommunicator(ParallelMode.TENSOR)

# for tp recompute communication optimization, mark the last block layer
for row_parallel_linear in _submodule_filter(model, RowParallelLinear):
if getattr(row_parallel_linear, "last_block_layer", False):
row_parallel_linear.register_communicator(
TensorParallelCommunicator(
process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW, last_block_layer=True
)
)
# sequence parallel
if gpc.config.parallel.tensor.mode in ("msp", "fsp"):
save_total_input_as_activation = gpc.config.parallel.tensor.mode == "msp"
@@ -296,6 +305,18 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
)
)

# for tp recompute communication optimization, mark the last block layer
for row_parallel_linear in _submodule_filter(model, RowParallelLinear):
if getattr(row_parallel_linear, "last_block_layer", False):
row_parallel_linear.register_communicator(
SequenceParallelCommunicator(
gpc.get_group(ParallelMode.TENSOR),
role=LinearRole.ROW,
save_total_input_as_activation=save_total_input_as_activation,
last_block_layer=True,
)
)

_head_communicator = HeadSequenceParallelCommunicator(
ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation
)
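The re-registration pass above can be pictured with a small single-process sketch (mock modules; the plain modules() walk stands in for _submodule_filter, and NoCommCommunicator for a communicator constructed with last_block_layer=True):

import torch.nn as nn

class TaggedLinear(nn.Linear):
    # Stand-in for RowParallelLinear carrying the last_block_layer tag.
    last_block_layer = False

    def register_communicator(self, communicator) -> None:
        self._communicator = communicator

class NoCommCommunicator:
    """Placeholder for a communicator constructed with last_block_layer=True."""

model = nn.Sequential(TaggedLinear(8, 8), TaggedLinear(8, 8))
model[1].last_block_layer = True  # pretend this is the w2 of the last block

for module in model.modules():  # plays the role of _submodule_filter
    if isinstance(module, TaggedLinear) and getattr(module, "last_block_layer", False):
        module.register_communicator(NoCommCommunicator())

assert isinstance(model[1]._communicator, NoCommCommunicator)
assert not hasattr(model[0], "_communicator")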