
Commit

isp
li126com committed May 23, 2024
1 parent efaf43d commit 216691b
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -126,15 +126,12 @@ def __init__(
self.require_grad_sync = True

# if process_group is none, will use the default one
self.dp_pg = gpc.get_group(ParallelMode.DATA)
self.zero_pg = gpc.get_group(ParallelMode.ZERO1)
self._local_rank = gpc.get_local_rank(ParallelMode.DATA)
self._world_size = gpc.get_world_size(ParallelMode.DATA)
self.zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
self.zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)

self._broadcast_parallel_mode = ParallelMode.ZERO1

# extra dp
# This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
# Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
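The comment above is the contract for the MoE path: non-MoE gradients are synchronized over the full data-parallel group, while MoE gradients only need the smaller extra-dp group (dp_world_size = moe_duplicates * extra_dp_size). A minimal sketch of building such subgroups with plain torch.distributed, assuming the default process group is already initialized; the helper name and the consecutive-rank layout are illustrative, not InternEvo's actual topology code:

import torch.distributed as dist

def build_extra_dp_groups(dp_world_size: int, extra_dp_size: int):
    # Illustrative only: one subgroup per replica set, built from consecutive ranks.
    # Every rank must call dist.new_group the same number of times, in the same order.
    assert dp_world_size % extra_dp_size == 0
    groups = []
    for start in range(0, dp_world_size, extra_dp_size):
        groups.append(dist.new_group(ranks=list(range(start, start + extra_dp_size))))
    return groups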
@@ -169,9 +166,6 @@ def __init__(
# gradient clipping
self._clip_grad_norm = clip_grad_norm

self.padding_grad = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
self.padding_tensor = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())

# master weights copy
self._master_weights = master_weights

@@ -428,12 +422,15 @@ def _create_master_param_current_rank(self, param_list):
def _run_reduction(self):
for group_id in range(self.num_param_groups):
current_bucket = self._bucket_store[group_id]
dp_parallel_mode = current_bucket.get_dp_parallel_mode()
reduce_group = gpc.get_group(dp_parallel_mode)
world_size = gpc.get_world_size(dp_parallel_mode)
if current_bucket.num_elements_in_bucket() > 0:
current_bucket.build_grad_in_bucket()

if self.moe_extra_dp_pg is None:
flat_grads = current_bucket.get_flatten_grad()
flat_grads /= self._world_size
flat_grads /= world_size
else:
# record moe and non moe param
moe_list = []
@@ -460,7 +457,7 @@ def _run_reduction(self):
for grad_list in non_moe_grad_list:
non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
non_moe_flat_grads /= self._world_size
non_moe_flat_grads /= world_size

if len(moe_grad_list) > 0:
moe_flat_grads = []
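Both branches above prepare for the collective the same way: pack the bucket's per-parameter gradients into one flat tensor with _flatten_dense_tensors, then pre-divide by the reduce group's world size so that a SUM reduction yields the mean. A self-contained illustration of that flatten-and-average step (shapes and the world size are made up):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.randn(4), torch.randn(2, 3)]       # stand-ins for a bucket's per-param grads
world_size = 8                                    # illustrative reduce-group size

flat = _flatten_dense_tensors(grads)              # one contiguous tensor, numel = 4 + 6
flat /= world_size                                # pre-divide so SUM-reduction gives the average
restored = _unflatten_dense_tensors(flat, grads)  # views shaped like the original grads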
@@ -496,7 +493,7 @@ def _run_reduction(self):

if not self._partition_grads:
if self.moe_extra_dp_pg is None:
dist.all_reduce(flat_grads, group=self.dp_pg)
dist.all_reduce(flat_grads, group=reduce_group)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)

@@ -508,9 +505,9 @@ def _run_reduction(self):
else:
# sync non moe param in global dp group
if len(non_moe_grad_list) > 0:
dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
dist.all_reduce(non_moe_flat_grads, group=reduce_group)
flat_grads_per_rank = non_moe_flat_grads.split(
non_moe_flat_grads.numel() // self._world_size
non_moe_flat_grads.numel() // self.zero_world_size
)
self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
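When gradients are not partitioned, the flat gradient is all-reduced over the group the bucket reported (reduce_group), then split into equal per-rank chunks only so it can be copied back into the bucket's per-rank slices. A hedged sketch of that pattern; the helper name is illustrative and it assumes the flat tensor was padded to a multiple of world_size:

import torch
import torch.distributed as dist

def allreduce_and_split(flat_grads: torch.Tensor, group, world_size: int):
    # SUM across every rank in the group; flat_grads was already divided by world_size
    dist.all_reduce(flat_grads, group=group)
    # one chunk per rank, mirroring how the bucket stores per-rank gradient slices
    return flat_grads.split(flat_grads.numel() // world_size)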

@@ -522,18 +519,18 @@ def _run_reduction(self):

else:
if self.moe_extra_dp_pg is None:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
flat_grads_list = list(flat_grads.split(len(flat_grads) // world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
dist.reduce_scatter(recieved_grad, flat_grads_list, group=reduce_group)

if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)

grad_in_bucket_current_rank = current_bucket.get_grad()[self._local_rank]
grad_in_bucket_current_rank = current_bucket.get_grad()[self.zero_local_rank]
self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
else:
# categorize moe and non moe param
grad_in_bucket_current_rank = current_bucket.get_grad()[self._local_rank]
grad_in_bucket_current_rank = current_bucket.get_grad()[self.zero_local_rank]
moe_grad_in_bucket_current_rank = []
non_moe_grad_in_bucket_current_rank = []
for idx, grad in enumerate(grad_in_bucket_current_rank):
@@ -544,10 +541,10 @@ def _run_reduction(self):

if len(non_moe_grad_list) > 0:
flat_grads_list = list(
non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
non_moe_flat_grads.split(len(non_moe_flat_grads) // world_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
dist.reduce_scatter(recieved_grad, flat_grads_list, group=reduce_group)
self._update_partitoned_grad(
non_moe_grad_in_bucket_current_rank,
recieved_grad,
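With gradient partitioning enabled, each rank only needs its own shard of the reduced bucket, so reduce_scatter replaces all_reduce: the flat gradient is split into world_size chunks and chunk i is summed onto rank i. A hedged sketch of that collective, again assuming the flat length is a multiple of world_size; the helper name is illustrative:

import torch
import torch.distributed as dist

def reduce_scatter_bucket(flat_grads: torch.Tensor, group, world_size: int) -> torch.Tensor:
    chunks = list(flat_grads.split(flat_grads.numel() // world_size))
    shard = torch.zeros_like(chunks[0])
    # each rank receives the summed values for its own chunk of the bucket
    dist.reduce_scatter(shard, chunks, group=group)
    return shard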
@@ -773,7 +770,7 @@ def step(self, closure=None):
# InternEvo
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
group_name = f"{group_id}_{group_name}"
total_norms[group_name] = self._compute_norm(group_id=group_id, gradients=working_grads, parameters=param_group, zero_mode=self._broadcast_parallel_mode)
total_norms[group_name] = self._compute_norm(group_id=group_id, gradients=working_grads, parameters=param_group, zero_mode=ParallelMode.ZERO1)

self._grad_store.reset_grads_by_group_id(group_id)
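Because each ZeRO rank only computes the norm over the gradients it owns, the squared norms have to be combined across the ZERO1 group before the square root is taken, which is why _compute_norm is called with zero_mode=ParallelMode.ZERO1. A rough sketch of that reduction (not InternEvo's _compute_norm; CPU tensors and a gloo-style group are assumed for simplicity):

import torch
import torch.distributed as dist

def sharded_l2_norm(local_grads, group) -> float:
    total_sq = torch.zeros(1)                      # running sum of squared entries on this rank
    for g in local_grads:
        total_sq += g.detach().float().pow(2).sum().cpu()
    dist.all_reduce(total_sq, group=group)         # add up the squared norms from every shard
    return float(total_sq.sqrt())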

@@ -822,7 +819,7 @@ def step(self, closure=None):
address=gpc.config.monitor.alert.feishu_alert_address,
message="Overflow occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self._grad_store._grads_of_params = dict()
self.zero_grad()
return False, total_norms

@@ -833,7 +830,7 @@ def step(self, closure=None):
address=gpc.config.monitor.alert.feishu_alert_address,
message="Nan grad norm occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self._grad_store._grads_of_params = dict()
self.zero_grad()
return False, total_norms
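Both failure branches above behave the same way: on mixed-precision overflow or a NaN gradient norm, the stored gradients are dropped, zero_grad() is called, and step() returns False so the parameter update is skipped. A tiny sketch of the detection side of that guard (the helper is illustrative):

import math

def should_skip_step(total_norms: dict) -> bool:
    # skip the optimizer update if any group's norm came back inf or NaN
    return any(math.isinf(n) or math.isnan(n) for n in total_norms.values())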

