Commit: temp
li126com committed Sep 24, 2024
1 parent f2f8e5b commit 6fbd1de
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions internlm/core/scheduler/pipeline_scheduler.py
@@ -911,6 +911,16 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T

         # Run 1F1B in steady state.
         for i in range(num_1f1b_micropairs):
+
+            if i > 0:
+                input_obj = None
+                if not gpc.is_first_rank(ParallelMode.PIPELINE):
+                    input_obj = comm.recv_forward(
+                        forward_recv_shapes,
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors,
+                    )
+
             # Perform forward computation
             output_obj, moe_loss = self._forward_step(
                 engine,
@@ -946,24 +956,30 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T

             input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)
 
-            if i == (num_1f1b_micropairs - 1):
-                input_obj = None
-                if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                    comm.send_backward(
-                        input_obj_grad,
-                        scatter_gather_tensors=self.scatter_gather_tensors,
-                    )
-            else:
-                if gpc.is_first_rank(ParallelMode.PIPELINE):
-                    input_obj = None
-                else:
-                    input_obj = comm.send_backward_recv_forward(
-                        input_obj_grad,
-                        forward_recv_shapes,
-                        dtype=self.dtype,
-                        scatter_gather_tensors=self.scatter_gather_tensors,
-                    )
-
+            # if i == (num_1f1b_micropairs - 1):
+            #     input_obj = None
+            #     if not gpc.is_first_rank(ParallelMode.PIPELINE):
+            #         comm.send_backward(
+            #             input_obj_grad,
+            #             scatter_gather_tensors=self.scatter_gather_tensors,
+            #         )
+            # else:
+            #     if gpc.is_first_rank(ParallelMode.PIPELINE):
+            #         input_obj = None
+            #     else:
+            #         input_obj = comm.send_backward_recv_forward(
+            #             input_obj_grad,
+            #             forward_recv_shapes,
+            #             dtype=self.dtype,
+            #             scatter_gather_tensors=self.scatter_gather_tensors,
+            #         )
+
+            if not gpc.is_first_rank(ParallelMode.PIPELINE):
+                comm.send_backward(
+                    input_obj_grad,
+                    scatter_gather_tensors=self.scatter_gather_tensors,
+                )
+
             WeightGradStore.flush()
             if i >= gpc.get_local_rank(ParallelMode.PIPELINE):
                 WeightGradStore.pop()
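
With this change, the steady-state 1F1B loop receives the next microbatch's activation at the top of each iteration (for i > 0) and issues the backward send unconditionally at the end of the iteration, instead of special-casing the last iteration and fusing the send with the next receive via comm.send_backward_recv_forward. A minimal, self-contained sketch of that reordering is below (illustration only; recv_forward, send_backward, forward_step, and backward_step are hypothetical stubs, not the internlm communication APIs):

# Illustration only: a toy version of the reordered steady-state loop.
# The helpers below are hypothetical stubs, not internlm's comm / gpc APIs.

def recv_forward():
    # Stand-in for comm.recv_forward: pretend to receive an activation
    # from the previous pipeline stage.
    return "activation"

def send_backward(grad):
    # Stand-in for comm.send_backward: pretend to send a gradient
    # to the previous pipeline stage.
    print("send_backward:", grad)

def forward_step(inp):
    return "output_of_" + str(inp)

def backward_step(inp, out):
    return "grad_of_" + str(inp)

def steady_state(num_1f1b_micropairs, is_first_stage, warmup_input):
    input_obj = warmup_input
    for i in range(num_1f1b_micropairs):
        # New ordering: the forward receive happens at the top of the
        # iteration (i > 0); iteration 0 reuses the input from warmup.
        if i > 0:
            input_obj = None if is_first_stage else recv_forward()

        output_obj = forward_step(input_obj)
        input_obj_grad = backward_step(input_obj, output_obj)

        # New ordering: the backward send is issued unconditionally
        # (except on the first stage), no longer fused with the next
        # forward receive via send_backward_recv_forward.
        if not is_first_stage:
            send_backward(input_obj_grad)

steady_state(num_1f1b_micropairs=3, is_first_stage=False, warmup_input=recv_forward())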
