From 6fbd1de1d323da4acdf8a305c6e0fb372fa5512b Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Tue, 24 Sep 2024 10:44:41 +0800 Subject: [PATCH] temp --- internlm/core/scheduler/pipeline_scheduler.py | 52 ++++++++++++------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index a2e08914..58221e38 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -911,6 +911,16 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T # Run 1F1B in steady state. for i in range(num_1f1b_micropairs): + + if i > 0: + input_obj = None + if not gpc.is_first_rank(ParallelMode.PIPELINE): + input_obj = comm.recv_forward( + forward_recv_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + # Perform forward computation output_obj, moe_loss = self._forward_step( engine, @@ -946,24 +956,30 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss) - if i == (num_1f1b_micropairs - 1): - input_obj = None - if not gpc.is_first_rank(ParallelMode.PIPELINE): - comm.send_backward( - input_obj_grad, - scatter_gather_tensors=self.scatter_gather_tensors, - ) - else: - if gpc.is_first_rank(ParallelMode.PIPELINE): - input_obj = None - else: - input_obj = comm.send_backward_recv_forward( - input_obj_grad, - forward_recv_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors, - ) - + # if i == (num_1f1b_micropairs - 1): + # input_obj = None + # if not gpc.is_first_rank(ParallelMode.PIPELINE): + # comm.send_backward( + # input_obj_grad, + # scatter_gather_tensors=self.scatter_gather_tensors, + # ) + # else: + # if gpc.is_first_rank(ParallelMode.PIPELINE): + # input_obj = None + # else: + # input_obj = comm.send_backward_recv_forward( + # input_obj_grad, + # forward_recv_shapes, + # dtype=self.dtype, + # scatter_gather_tensors=self.scatter_gather_tensors, + # ) + + if not gpc.is_first_rank(ParallelMode.PIPELINE): + comm.send_backward( + input_obj_grad, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + WeightGradStore.flush() if i >= gpc.get_local_rank(ParallelMode.PIPELINE): WeightGradStore.pop()