
Commit v4
li126com committed Oct 28, 2024
1 parent 0607da8 commit eb987b3
Showing 4 changed files with 18 additions and 18 deletions.
configs/7B_internlm2.py (3 changes: 2 additions & 1 deletion)
@@ -174,14 +174,15 @@
  1. size: int, the size of pipeline parallel.
  2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
     defaults to False.
+ 3. mode: str, the pipeline parallel mode, should be in ['1f1b', 'zbh1', 'zbv']. The default is 1f1b.
  weight parallel (dict):
  1. size: int, the size of weight parallel.
  2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
  """
  parallel = dict(
      zero1=dict(size=-1),
      tensor=dict(size=2, mode="isp"),
-     pipeline=dict(size=1, interleaved_overlap=True),
+     pipeline=dict(size=1, interleaved_overlap=True, mode="1f1b"),
      weight=dict(size=2, overlap=True),
  )
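
The new mode key replaces the boolean zero_bubble switch that is dropped from configs/7B_sft.py below. As a rough, hypothetical sketch (not InternEvo's actual dispatch code), a consumer of this config might validate and route the mode like so; the scheduler descriptions in the mapping are illustrative placeholders:

# Hypothetical sketch: validating parallel.pipeline.mode from the config above.
pipeline_cfg = dict(size=1, interleaved_overlap=True, mode="1f1b")

SCHEDULERS = {
    "1f1b": "one-forward-one-backward scheduler",
    "zbh1": "zero-bubble ZB-H1 scheduler",
    "zbv": "zero-bubble ZB-V scheduler",
}

mode = pipeline_cfg.get("mode", "1f1b")  # default documented as 1f1b
if mode not in SCHEDULERS:
    raise ValueError(f"pipeline mode should be in {sorted(SCHEDULERS)}, got {mode!r}")
print(f"mode={mode}: {SCHEDULERS[mode]}")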

configs/7B_sft.py (3 changes: 1 addition & 2 deletions)
@@ -186,15 +186,14 @@
  1. size: int, the size of pipeline parallel (Default is 1F1B).
  2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
     defaults to False.
- 3. zero_bubble: bool, enable/disable zero bubble pipeline parallelism (ZB-H1), defaults to False.
  weight parallel (dict):
  1. size: int, the size of weight parallel.
  2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
  """
  parallel = dict(
      zero1=dict(size=-1),
      tensor=dict(size=1, mode="mtp"),
-     pipeline=dict(size=1, interleaved_overlap=True, zero_bubble=False),
+     pipeline=dict(size=1, interleaved_overlap=True),
      weight=dict(size=1, overlap=True),
  )
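
With zero_bubble removed, ZB-H1 is presumably requested through the new mode key instead; based on the allowed values documented in configs/7B_internlm2.py above, the old zero_bubble=True would now be written roughly as:

# Presumed replacement for the removed zero_bubble=True flag (assumption
# based on the documented mode values, not taken from this diff):
pipeline = dict(size=4, interleaved_overlap=True, mode="zbh1")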

internlm/core/scheduler/comm/p2p.py (4 changes: 2 additions & 2 deletions)
@@ -172,7 +172,7 @@ def _communicate(
  for req in reqs:
      req.wait()
  # To protect against race condition when using batch_isend_irecv().
- # internlm_accelerator.synchronize()
+ internlm_accelerator.synchronize()

  if recv_prev and recv_prev_split:
      if isinstance(tensor_recv_prev, torch.Tensor):
@@ -287,7 +287,7 @@ def _communicate_async(
  for req in reqs:  # pylint: disable=E0601
      req.wait()
  # To protect against race condition when using batch_isend_irecv().
- # internlm_accelerator.synchronize()
+ internlm_accelerator.synchronize()

  if recv_prev and recv_prev_split:
      if isinstance(tensor_recv_prev, torch.Tensor):
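Both hunks re-enable the device synchronize that had been commented out after waiting on the batched point-to-point requests. With NCCL, req.wait() orders the communication on the device stream rather than blocking the host, so a device-level synchronize is a common guard before the received tensors are consumed elsewhere. A standalone sketch of the pattern in plain PyTorch (not InternEvo's accelerator wrapper; assumes a torchrun launch with one CUDA device per rank):

# Ring exchange via batch_isend_irecv, followed by the synchronize guard.
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

send_buf = torch.full((4,), float(rank), device="cuda")
recv_buf = torch.empty(4, device="cuda")
ops = [
    dist.P2POp(dist.isend, send_buf, (rank + 1) % world_size),
    dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world_size),
]
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
    req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
print(f"rank {rank} received {recv_buf.tolist()}")
dist.destroy_process_group()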
internlm/train/pipeline.py (26 changes: 13 additions & 13 deletions)
@@ -498,19 +498,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator
          use_apex_adam=use_apex_adam,
      )

-     # if (
-     #     zero_cfg.overlap_sync_grad
-     #     and gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
-     #     and gpc.is_pipeline_first_stage() is False
-     # ):
-     #     # When pipeline parallelism is enabled, we prefer to only enable optimizer
-     #     # gradient communication overlap in the first stage, to avoid amplifying
-     #     # the communication overhead stage by stage in cases where the optimizer
-     #     # communication overhead is greater than the compute overhead.
-     #     # For pipeline stages except the first, even if overlap is not enabled,
-     #     # their gradient synchronization overhead can be well hidden by
-     #     # the inherent bubbles of pipeline parallelism.
-     #     zero_cfg.overlap_sync_grad = False
+     if (
+         zero_cfg.overlap_sync_grad
+         and gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
+         and gpc.is_pipeline_first_stage() is False
+     ):
+         # When pipeline parallelism is enabled, we prefer to only enable optimizer
+         # gradient communication overlap in the first stage, to avoid amplifying
+         # the communication overhead stage by stage in cases where the optimizer
+         # communication overhead is greater than the compute overhead.
+         # For pipeline stages except the first, even if overlap is not enabled,
+         # their gradient synchronization overhead can be well hidden by
+         # the inherent bubbles of pipeline parallelism.
+         zero_cfg.overlap_sync_grad = False

      if zero_cfg.overlap_sync_param:
          param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, isp_communicator)
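The restored block reduces to a simple predicate: gradient-sync overlap stays on only for the first pipeline stage, since later stages can hide synchronization cost inside pipeline bubbles. A minimal sketch with hypothetical stand-ins for gpc and zero_cfg:

# Hypothetical stand-in for the restored gating logic; the real code reads
# gpc and zero_cfg from InternEvo's global context rather than arguments.
def disable_grad_overlap(overlap_sync_grad: bool, pipeline_size: int, stage_id: int) -> bool:
    using_pipeline = pipeline_size > 1
    is_first_stage = stage_id == 0
    return overlap_sync_grad and using_pipeline and not is_first_stage

# Example: in a 4-stage pipeline, stage 0 keeps overlap; stages 1-3 turn it off.
for stage in range(4):
    print(stage, disable_grad_overlap(True, 4, stage))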
