Commit

fix speed monitor for TP
epwalsh committed Dec 3, 2024
1 parent 95c41ff commit a0005cd
Showing 1 changed file with 8 additions and 1 deletion.
src/olmo_core/train/callbacks/speed_monitor.py (9 changes: 8 additions & 1 deletion)
@@ -4,6 +4,7 @@
 
 import torch
 
+from olmo_core.distributed.utils import get_world_size
 from olmo_core.nn.transformer import Transformer
 
 from .callback import Callback
@@ -33,6 +34,7 @@ class SpeedMonitorCallback(Callback):
     _batch_load_time: float = 0.0
     _step_tokens: int = 0
     _step_seq_len: int = 0
+    _parallel_degree: int = 1
 
     def _get_num_flops_per_token(self, seq_len: int) -> Optional[int]:
         if self.num_flops_per_token is not None:
@@ -45,6 +47,11 @@ def _get_num_flops_per_token(self, seq_len: int) -> Optional[int]:
     def pre_train(self):
         self._first_step = True
 
+        if self.trainer.dp_process_group is not None:
+            self._parallel_degree = get_world_size() // get_world_size(
+                self.trainer.dp_process_group
+            )
+
         if self.device_peak_flops is None and self.trainer.device.type == "cuda":
             device_name = torch.cuda.get_device_name(self.trainer.device)
             if self.trainer.autocast_precision == torch.bfloat16:
@@ -73,7 +80,7 @@ def pre_step(self, batch: Dict[str, Any]):
             # unusually long.
             return
 
-        self._step_tokens = batch["input_ids"].numel()
+        self._step_tokens = batch["input_ids"].numel() // self._parallel_degree
        self._step_seq_len = batch["input_ids"].shape[1]
         self._total_steps += 1
         self._total_tokens += self._step_tokens
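
For context: under tensor parallelism, every rank in a model-parallel group processes the same batch, so counting batch["input_ids"].numel() on each rank over-counts throughput by the model-parallel degree. The change above divides the per-rank token count by that degree. A minimal sketch of the accounting follows; the function name and arguments (effective_step_tokens, world_size, dp_world_size) are illustrative assumptions, not olmo_core APIs.

def effective_step_tokens(batch_numel: int, world_size: int, dp_world_size: int) -> int:
    # Illustrative sketch, not olmo_core code: ranks outside the data-parallel
    # dimension (e.g. tensor-parallel peers) see the same tokens, so divide the
    # raw per-rank count by the model-parallel degree.
    parallel_degree = world_size // dp_world_size
    return batch_numel // parallel_degree

# Example: 8 ranks total with a data-parallel group of size 4 gives a
# model-parallel degree of 2, so a per-rank batch of 4096 token positions
# counts as 2048 unique tokens toward throughput.
assert effective_step_tokens(4096, world_size=8, dp_world_size=4) == 2048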
