From 216691b32a115ccdd98b69d9ee89129ac07ac581 Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Thu, 23 May 2024 18:28:47 +0800
Subject: [PATCH] isp

---
 .../solver/optimizer/hybrid_zero_optim.py | 37 +++++++++----------
 1 file changed, 17 insertions(+), 20 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 20485678..b7eb9c30 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -126,15 +126,12 @@ def __init__(
         self.require_grad_sync = True
 
         # if process_group is none, will use the default one
-        self.dp_pg = gpc.get_group(ParallelMode.DATA)
         self.zero_pg = gpc.get_group(ParallelMode.ZERO1)
 
         self._local_rank = gpc.get_local_rank(ParallelMode.DATA)
         self._world_size = gpc.get_world_size(ParallelMode.DATA)
         self.zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
         self.zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
-        self._broadcast_parallel_mode = ParallelMode.ZERO1
-
         # extra dp
         # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
         # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
@@ -169,9 +166,6 @@ def __init__(
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
 
-        self.padding_grad = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
-        self.padding_tensor = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
-
         # master weights copy
         self._master_weights = master_weights
 
@@ -428,12 +422,15 @@ def _create_master_param_current_rank(self, param_list):
     def _run_reduction(self):
         for group_id in range(self.num_param_groups):
             current_bucket = self._bucket_store[group_id]
+            dp_parallel_mode = current_bucket.get_dp_parallel_mode()
+            reduce_group = gpc.get_group(dp_parallel_mode)
+            world_size = gpc.get_world_size(dp_parallel_mode)
             if current_bucket.num_elements_in_bucket() > 0:
                 current_bucket.build_grad_in_bucket()
 
                 if self.moe_extra_dp_pg is None:
                     flat_grads = current_bucket.get_flatten_grad()
-                    flat_grads /= self._world_size
+                    flat_grads /= world_size
                 else:
                     # record moe and non moe param
                     moe_list = []
@@ -460,7 +457,7 @@ def _run_reduction(self):
                     for grad_list in non_moe_grad_list:
                         non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
                     non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
-                    non_moe_flat_grads /= self._world_size
+                    non_moe_flat_grads /= world_size
 
                     if len(moe_grad_list) > 0:
                         moe_flat_grads = []
@@ -496,7 +493,7 @@ def _run_reduction(self):
 
                 if not self._partition_grads:
                     if self.moe_extra_dp_pg is None:
-                        dist.all_reduce(flat_grads, group=self.dp_pg)
+                        dist.all_reduce(flat_grads, group=reduce_group)
 
                         if flat_grads.dtype != grad_dtype:
                             flat_grads = flat_grads.to(grad_dtype)
@@ -508,9 +505,9 @@ def _run_reduction(self):
                     else:
                         # sync non moe param in global dp group
                         if len(non_moe_grad_list) > 0:
-                            dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
+                            dist.all_reduce(non_moe_flat_grads, group=reduce_group)
                             flat_grads_per_rank = non_moe_flat_grads.split(
-                                non_moe_flat_grads.numel() // self._world_size
+                                non_moe_flat_grads.numel() // self.zero_world_size
                             )
                             self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
 
@@ -522,18 +519,18 @@ def _run_reduction(self):
 
                 else:
                     if self.moe_extra_dp_pg is None:
-                        flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
+                        flat_grads_list = list(flat_grads.split(len(flat_grads) // world_size))
                         recieved_grad = torch.zeros_like(flat_grads_list[0])
-                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=reduce_group)
 
                         if recieved_grad.dtype != grad_dtype:
                             recieved_grad = recieved_grad.to(grad_dtype)
 
-                        grad_in_bucket_current_rank = current_bucket.get_grad()[self._local_rank]
+                        grad_in_bucket_current_rank = current_bucket.get_grad()[self.zero_local_rank]
                         self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
                     else:
                         # categorize moe and non moe param
-                        grad_in_bucket_current_rank = current_bucket.get_grad()[self._local_rank]
+                        grad_in_bucket_current_rank = current_bucket.get_grad()[self.zero_local_rank]
                         moe_grad_in_bucket_current_rank = []
                         non_moe_grad_in_bucket_current_rank = []
                         for idx, grad in enumerate(grad_in_bucket_current_rank):
@@ -544,10 +541,10 @@ def _run_reduction(self):
 
                         if len(non_moe_grad_list) > 0:
                             flat_grads_list = list(
-                                non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
+                                non_moe_flat_grads.split(len(non_moe_flat_grads) // world_size)
                             )
                             recieved_grad = torch.zeros_like(flat_grads_list[0])
-                            dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+                            dist.reduce_scatter(recieved_grad, flat_grads_list, group=reduce_group)
                             self._update_partitoned_grad(
                                 non_moe_grad_in_bucket_current_rank,
                                 recieved_grad,
@@ -773,7 +770,7 @@ def step(self, closure=None):
             # InternEvo
             group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
             group_name = f"{group_id}_{group_name}"
-            total_norms[group_name] = self._compute_norm(group_id=group_id, gradients=working_grads, parameters=param_group, zero_mode=self._broadcast_parallel_mode)
+            total_norms[group_name] = self._compute_norm(group_id=group_id, gradients=working_grads, parameters=param_group, zero_mode=ParallelMode.ZERO1)
 
             self._grad_store.reset_grads_by_group_id(group_id)
 
@@ -822,7 +819,7 @@ def step(self, closure=None):
                     address=gpc.config.monitor.alert.feishu_alert_address,
                     message="Overflow occurs, please check it.",
                 )
-            self._grad_store._averaged_gradients = dict()
+            self._grad_store._grads_of_params = dict()
             self.zero_grad()
             return False, total_norms
 
@@ -833,7 +830,7 @@ def step(self, closure=None):
                    address=gpc.config.monitor.alert.feishu_alert_address,
                     message="Nan grad norm occurs, please check it.",
                 )
-            self._grad_store._averaged_gradients = dict()
+            self._grad_store._grads_of_params = dict()
             self.zero_grad()
             return False, total_norms
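
Note on the core change: each gradient bucket now reports the data-parallel mode its parameters belong to, and _run_reduction() derives the reduce group and world size from that mode instead of always using the global ParallelMode.DATA group, so ISP (weight-parallel) parameters can be reduced over their own communicator. The sketch below is a minimal illustration of that selection pattern only; ToyContext, the WEIGHT_DATA mode name, and the bucket dictionaries are hypothetical stand-ins for InternEvo's gpc and BucketStore, not its real API.

# Minimal sketch of per-bucket reduce-group selection.
# ToyContext and WEIGHT_DATA are hypothetical; only the selection pattern
# (mode -> group/world_size -> averaged reduction) mirrors the patch.
from enum import Enum


class ParallelMode(Enum):
    DATA = "data"          # global data-parallel group (name used in the patch)
    WEIGHT_DATA = "wdp"    # hypothetical mode for ISP/weight-parallel parameters


class ToyContext:
    """Hypothetical stand-in for gpc: maps a ParallelMode to a (group, world_size) pair."""

    def __init__(self, groups):
        self._groups = groups

    def get_group(self, mode):
        return self._groups[mode][0]

    def get_world_size(self, mode):
        return self._groups[mode][1]


def run_reduction(buckets, ctx):
    """Average each bucket's flattened gradient over the group that owns its parameters."""
    for bucket in buckets:
        mode = bucket["dp_parallel_mode"]       # what get_dp_parallel_mode() returns in the patch
        group = ctx.get_group(mode)             # reduce_group in the patch
        world_size = ctx.get_world_size(mode)   # world_size in the patch
        averaged = bucket["flat_grads"] / world_size
        # In the real optimizer this is dist.all_reduce(averaged, group=group),
        # or dist.reduce_scatter(...) when gradients are partitioned.
        print(f"bucket reduced in {mode.name} group ({group}), world_size={world_size}, grads={averaged}")


if __name__ == "__main__":
    ctx = ToyContext({
        ParallelMode.DATA: ("dp_group", 8),
        ParallelMode.WEIGHT_DATA: ("wdp_group", 2),
    })
    run_reduction(
        [
            {"dp_parallel_mode": ParallelMode.DATA, "flat_grads": 8.0},
            {"dp_parallel_mode": ParallelMode.WEIGHT_DATA, "flat_grads": 2.0},
        ],
        ctx,
    )

Keeping the group lookup per bucket, rather than per optimizer instance, lets buckets holding parameters from different parallel groups reduce over different communicators without changing the reduction logic itself.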