clean up
epwalsh committed Dec 3, 2024
1 parent a0005cd commit bed869f
Showing 5 changed files with 31 additions and 20 deletions.
8 changes: 2 additions & 6 deletions src/examples/llama/train.py
@@ -193,12 +193,8 @@ def main(run_name: str, overrides: List[str]):
     )
     optim = config.optim.build(model)
     dataset = config.dataset.build()
-    data_loader = config.data_loader.build(
-        dataset, dp_process_group=get_dp_process_group(world_mesh)
-    )
-    trainer = config.trainer.build(
-        model, optim, data_loader, dp_process_group=get_dp_process_group(world_mesh)
-    )
+    data_loader = config.data_loader.build(dataset, mesh=world_mesh)
+    trainer = config.trainer.build(model, optim, data_loader, mesh=world_mesh)
 
     # Save config to W&B and each checkpoint dir.
     config_dict = config.as_config_dict()
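The new mesh=world_mesh keyword replaces the manual get_dp_process_group(world_mesh) lookup at each call site; the same two-line simplification is applied in experiment.py and model_ladder.py below. Because dp_process_group remains a keyword argument on both builders (see the signatures later in this diff), the explicit-process-group path should still work. A minimal sketch of that fallback usage, assuming config, dataset, model, optim, and world_mesh come from the surrounding script:

    # Sketch only: explicit process-group path, which this commit appears to keep as a fallback.
    from olmo_core.distributed.parallel import get_dp_process_group

    dp_pg = get_dp_process_group(world_mesh)  # same helper the call sites used before
    data_loader = config.data_loader.build(dataset, dp_process_group=dp_pg)
    trainer = config.trainer.build(model, optim, data_loader, dp_process_group=dp_pg)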
13 changes: 13 additions & 0 deletions src/olmo_core/data/data_loader.py
@@ -35,9 +35,11 @@
import torch
import torch.distributed as dist
import torch.utils.data
from torch.distributed import DeviceMesh

from ..aliases import PathOrStr
from ..config import Config
from ..distributed.parallel import get_dp_process_group
from ..distributed.utils import barrier, get_fs_local_rank, get_rank, get_world_size
from ..exceptions import OLMoConfigurationError
from ..utils import get_default_device, roundrobin, threaded_generator
@@ -913,14 +915,25 @@ def build(
        dataset: NumpyDatasetBase,
        *,
        collator: Optional[DataCollator] = None,
        mesh: Optional[DeviceMesh] = None,
        dp_process_group: Optional[dist.ProcessGroup] = None,
    ) -> NumpyDataLoaderBase:
        """
        Construct the :class:`NumpyDataLoaderBase`.

        :param dataset: The dataset.
        :param mesh: An optional ``DeviceMesh`` that defines the data parallel dimensions. Ideally
            you should create this mesh using :func:`~olmo_core.distributed.parallel.build_device_mesh()`
            or equivalently :meth:`olmo_core.nn.transformer.TransformerConfig.build_mesh()`.
            Alternatively, you can pass ``dp_process_group`` directly.
        :param dp_process_group: The data parallel process group.
        """
        if self.work_dir is not None and not dataset.work_dir_set:
            dataset.work_dir = Path(self.work_dir)

        if dp_process_group is None and mesh is not None:
            dp_process_group = get_dp_process_group(mesh)

        dataset.prepare()

        data_loader = NumpyDataLoaderBase.wrap_numpy_dataset(
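With this change a data loader can be built straight from the device mesh; when no explicit process group is passed, the builder derives one from the mesh via get_dp_process_group(mesh). A minimal sketch of the new path, assuming a NumpyDataLoaderConfig instance named data_loader_config and a world_mesh created with build_device_mesh() (these variable names are illustrative, not from this diff):

    # Sketch only: data_loader_config, dataset, and world_mesh are assumed to come
    # from the caller's setup, as in the example scripts above.
    data_loader = data_loader_config.build(dataset, mesh=world_mesh)

    # An explicitly passed dp_process_group takes precedence: per the fallback added
    # above, the mesh is only consulted when dp_process_group is None.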
9 changes: 2 additions & 7 deletions src/olmo_core/internal/experiment.py
@@ -15,7 +15,6 @@
     VSLCurriculumConfig,
     VSLCurriculumType,
 )
-from olmo_core.distributed.parallel import get_dp_process_group
 from olmo_core.distributed.utils import get_local_rank
 from olmo_core.float8 import Float8Config
 from olmo_core.launch.beaker import BeakerLaunchConfig
@@ -276,12 +275,8 @@ def train(config: ExperimentConfig):
     )
     optim = config.optim.build(model)
     dataset = config.dataset.build()
-    data_loader = config.data_loader.build(
-        dataset, dp_process_group=get_dp_process_group(world_mesh)
-    )
-    trainer = config.trainer.build(
-        model, optim, data_loader, dp_process_group=get_dp_process_group(world_mesh)
-    )
+    data_loader = config.data_loader.build(dataset, mesh=world_mesh)
+    trainer = config.trainer.build(model, optim, data_loader, mesh=world_mesh)
 
     # Record the config to W&B/Comet and each checkpoint dir.
     config_dict = config.as_config_dict()
9 changes: 2 additions & 7 deletions src/olmo_core/internal/model_ladder.py
@@ -7,7 +7,6 @@
 
 from olmo_core.config import Config, StrEnum
 from olmo_core.data import NumpyDataLoaderConfig, NumpyDatasetConfig
-from olmo_core.distributed.parallel import get_dp_process_group
 from olmo_core.distributed.utils import get_local_rank
 from olmo_core.launch.beaker import BeakerLaunchConfig
 from olmo_core.model_ladder import ModelLadder, ModelSize
@@ -77,12 +76,8 @@ def run(self, config: LadderRunConfig):
         )
         optim = config.optim.build(model)
         dataset = config.dataset.build()
-        data_loader = config.data_loader.build(
-            dataset, dp_process_group=get_dp_process_group(world_mesh)
-        )
-        trainer = config.trainer.build(
-            model, optim, data_loader, dp_process_group=get_dp_process_group(world_mesh)
-        )
+        data_loader = config.data_loader.build(dataset, mesh=world_mesh)
+        trainer = config.trainer.build(model, optim, data_loader, mesh=world_mesh)
 
         # Record the config to W&B/Comet and each checkpoint dir.
         config_dict = config.as_config_dict()
12 changes: 12 additions & 0 deletions src/olmo_core/train/config.py
@@ -7,10 +7,12 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.optim import Optimizer

from ..config import Config, DType
from ..data import DataLoaderBase
from ..distributed.parallel import get_dp_process_group
from ..exceptions import OLMoConfigurationError
from ..io import is_url
from ..utils import get_default_device
@@ -73,6 +75,8 @@ def build(
        model: nn.Module,
        optim: Optimizer,
        data_loader: DataLoaderBase,
        *,
        mesh: Optional[DeviceMesh] = None,
        dp_process_group: Optional[dist.ProcessGroup] = None,
        checkpointer_pg: Optional[dist.ProcessGroup] = None,
    ) -> Trainer:
@@ -82,9 +86,17 @@
        :param model: The model to train.
        :param optim: The optimizer to use.
        :param data_loader: The data loader to train on.
        :param mesh: An optional ``DeviceMesh`` that defines the data parallel dimensions. Ideally
            you should create this mesh using :func:`~olmo_core.distributed.parallel.build_device_mesh()`
            or equivalently :meth:`olmo_core.nn.transformer.TransformerConfig.build_mesh()`.
            Alternatively, you can pass ``dp_process_group`` directly.
        :param dp_process_group: The data parallel process group.
        """
        kwargs = self.as_dict(exclude_none=True, recurse=False)

        if dp_process_group is None and mesh is not None:
            dp_process_group = get_dp_process_group(mesh)

        checkpointer = Checkpointer(
            save_overwrite=kwargs["save_overwrite"],
            process_group=checkpointer_pg,
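The trainer config's build() mirrors the data loader change: the mesh is accepted as a keyword-only argument and resolved to a data-parallel process group only when dp_process_group is not supplied. A hedged sketch, assuming a trainer_config object and the model, optimizer, data loader, and world_mesh built earlier in the training script (names are illustrative):

    # Sketch only: trainer_config, model, optim, data_loader, and world_mesh are
    # assumed from the surrounding training script.
    trainer = trainer_config.build(model, optim, data_loader, mesh=world_mesh)

    # Per the fallback added above, this is equivalent to resolving the group by hand:
    # trainer_config.build(model, optim, data_loader,
    #                      dp_process_group=get_dp_process_group(world_mesh))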
