fix ci
li126com committed Oct 28, 2024
1 parent eb987b3 commit b34879a
Showing 6 changed files with 24 additions and 16 deletions.
5 changes: 4 additions & 1 deletion internlm/core/parallel/comm/isp.py
@@ -692,7 +692,10 @@ def after_backward(self, scheduler, inputs_grad) -> None:  # pylint: disable=W06
         if self._isp_communicator and self._isp_communicator.overlap:
             self._zero_optim.accumulate_left_grads_after_backward()
 
-        if not self._zero_optim.skip_grad_reduce:
+        if (
+            getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() in ["ZBV", "ZBH1"]
+            and not self._zero_optim.skip_grad_reduce
+        ):
             self._zero_optim.reduce_left_grads_after_backward()
 
     def post_helper_func(self, scheduler, outputs, label) -> None:  # pylint: disable=W0613
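
Note: the recurring fix in this commit is reading the pipeline mode with a default rather than indexing it directly, so configs that never set "mode" fall back to 1F1B instead of raising. A minimal, hypothetical sketch of that defaulting pattern (the PipelineConfig stand-in and resolve_pp_mode helper are illustrative, not part of the repository):

    class PipelineConfig:
        """Toy stand-in for a pipeline config entry that may or may not define `mode`."""

        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)


    def resolve_pp_mode(pipeline_cfg) -> str:
        # Fall back to "1F1B" when `mode` is absent and normalize case,
        # mirroring the getattr(..., "mode", "1F1B").upper() pattern in the diff.
        return getattr(pipeline_cfg, "mode", "1F1B").upper()


    print(resolve_pp_mode(PipelineConfig(size=2)))               # -> "1F1B"
    print(resolve_pp_mode(PipelineConfig(size=2, mode="zbh1")))  # -> "ZBH1"
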
2 changes: 1 addition & 1 deletion internlm/core/parallel/shard.py
@@ -186,7 +186,7 @@ def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: i
         if chunk_size == 0:
             raise ValueError("Some nodes in Pipeline have no requests")
 
-        if gpc.config.parallel["pipeline"]["mode"] == "ZBV" and idx == 1:
+        if getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() == "ZBV" and idx == 1:
             for p in range(pipeline_parallel_size - 1, -1, -1):
                 st = base_idx
                 base_idx += chunk_size + ((pipeline_parallel_size - p - 1) >= left)
8 changes: 5 additions & 3 deletions internlm/initialize/initialize_trainer.py
@@ -95,14 +95,16 @@ def _data_preparation_func(_data, _label):
 
         return _data, _label
 
+    pp_mode = getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper()
+
     if gpc.is_using_parallel_mode(ParallelMode.PIPELINE):
         gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
         tensor_shape = get_tensor_shape()
         use_interleaved = (
             hasattr(gpc.config, "model")
             and hasattr(gpc.config.model, "num_chunks")
             and gpc.config.model.num_chunks > 1
-            and gpc.config.parallel["pipeline"]["mode"] == "1F1B"
+            and pp_mode == "1F1B"
         )
         scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
         if use_interleaved:
@@ -120,7 +122,7 @@ def _data_preparation_func(_data, _label):
                 scheduler_hooks=scheduler_hooks,
                 communication_overlap=communication_overlap,
             )
-        elif gpc.config.parallel["pipeline"]["mode"] == "ZBH1":
+        elif pp_mode == "ZBH1":
            scheduler = ZeroBubblePipelineScheduler(
                 data_process_func=_data_preparation_func,
                 num_microbatches=gpc.config.NUM_MICRO_BATCHES,
@@ -130,7 +132,7 @@ def _data_preparation_func(_data, _label):
                 scheduler_hooks=scheduler_hooks,
                 optimizer=optimizer,
             )
-        elif gpc.config.parallel["pipeline"]["mode"] == "ZBV":
+        elif pp_mode == "ZBV":
             scheduler = ZeroBubblePipelineVShapeScheduler(
                 num_microbatches=gpc.config.NUM_MICRO_BATCHES,
                 num_chunks=gpc.config.model.num_chunks,
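
Note: the hoisted pp_mode value now drives a single if/elif dispatch over pipeline schedulers. A simplified, hypothetical sketch of that selection logic (the pick_scheduler helper and its string return values are illustrative; only the zero-bubble scheduler names appear in the diff itself, the other class names are assumptions):

    def pick_scheduler(pp_mode: str, num_chunks: int) -> str:
        """Return the scheduler the updated dispatch would choose (names as strings)."""
        if num_chunks > 1 and pp_mode == "1F1B":
            return "InterleavedPipelineScheduler"  # assumed interleaved 1F1B scheduler
        if pp_mode == "ZBH1":
            return "ZeroBubblePipelineScheduler"
        if pp_mode == "ZBV":
            return "ZeroBubblePipelineVShapeScheduler"
        return "PipelineScheduler"  # assumed plain 1F1B fallback


    print(pick_scheduler("1F1B", num_chunks=2))  # InterleavedPipelineScheduler
    print(pick_scheduler("ZBV", num_chunks=1))   # ZeroBubblePipelineVShapeScheduler
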
6 changes: 3 additions & 3 deletions internlm/model/modules/mlp.py
@@ -99,12 +99,12 @@ def __init__(
         self.w1 = new_linear(
             "w1", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert
         )
-        self.w3 = new_linear(
-            "w3", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert
-        )
         self.w2 = new_linear(
             "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert
         )
+        self.w3 = new_linear(
+            "w3", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert
+        )
 
     def forward(self, x):
         if not self.mlp_layer_fusion:
6 changes: 5 additions & 1 deletion tests/test_core/test_pipeline.py
@@ -153,6 +153,10 @@ def exam_pipeline_parallel(args):
         first_output = output_list[0]
         for i in range(1, 10):
             assert torch.equal(first_output, output_list[i])
+            print(
+                f"idx {i} pass: micro_num={micro_num}, num_chunks={num_chunks}, overlap={interleaved_overlap}",
+                flush=True,
+            )
 
         # check output
         torch_output = torch_model(input_ids=torch_xs)  # pylint: disable=E1102
@@ -167,7 +171,7 @@ def exam_pipeline_parallel(args):
         loose_close(torch_loss, loss[0], dtype=dtype)
 
 
-@pytest.mark.parametrize("micro_num", [4, 8, 16])
+@pytest.mark.parametrize("micro_num", [8, 16])
 @pytest.mark.parametrize("num_chunks", [1, 2, 4])
 @pytest.mark.parametrize("interleaved_overlap", [True, False])
 def test_pipeline_parallel(micro_num, num_chunks, interleaved_overlap):
13 changes: 6 additions & 7 deletions tests/test_training/test_loss.py
@@ -61,7 +61,7 @@ def train(
     load_ckpt: bool = False,
     model_type: str = "INTERNLM",
     optimizer_ver: str = "v1",
-    zero_bubble: bool = False,
+    pp_mode: str = "1F1B",
 ):
     # initialize distributed environment
     config = Config.from_file(CONFIG_FILE_PATH)
@@ -97,14 +97,13 @@
 
     # update parallel config
     config.parallel.tensor = dict(size=tp_size, mode=tp_mode)
-    if zero_bubble:
+    if pp_mode == "ZBH1":
         config.hybrid_zero_optimizer.overlap_sync_grad = False
-        config.parallel.pipeline = dict(size=pp_size, zero_bubble=True)
-    else:
-        config.parallel.pipeline = dict(size=pp_size)
+
+    config.parallel.pipeline = dict(size=pp_size, mode=pp_mode)
     config.parallel.weight = dict(size=wp_size, overlap=True)
     if interleaved is True:
-        config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True)
+        config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True, mode=pp_mode)
     config.model.num_chunks = num_chunks
 
     if "use_packed_dataset" not in config.data:
@@ -379,7 +378,7 @@ def test_training_loss_with_dp4_pp2():
 @pytest.mark.training_8GPU_4DP2PP_ZB
 def test_training_loss_with_dp4_pp2_zero_bubble():
     # model training
-    train(dp_size=4, pp_size=2, zero_bubble=True)
+    train(dp_size=4, pp_size=2, pp_mode="ZBH1")
 
     # print loss value
     print(f"cur_loss_list: {cur_loss_list}", flush=True)
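
Note: a minimal sketch of how the new pp_mode argument flows into the pipeline config under the revised train() (the build_pipeline_config helper is hypothetical, using plain dicts in place of the test's Config object):

    def build_pipeline_config(pp_size: int, pp_mode: str = "1F1B", interleaved: bool = False) -> dict:
        """Illustrative reconstruction of the pipeline-config update in the revised test."""
        if interleaved:
            return dict(size=pp_size, interleaved_overlap=True, mode=pp_mode)
        return dict(size=pp_size, mode=pp_mode)


    print(build_pipeline_config(2, pp_mode="ZBH1"))
    # {'size': 2, 'mode': 'ZBH1'}
    print(build_pipeline_config(2, interleaved=True))
    # {'size': 2, 'interleaved_overlap': True, 'mode': '1F1B'}
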
