
Add support for tensor parallelism and add OLMo2-26B model config / train script #117

Merged
43 commits merged on Dec 3, 2024

Changes from all commits

Commits (43)
0ef2a8a  Add OLMo2-26B config  (epwalsh, Nov 27, 2024)
d4c65ac  make model wider, shorter  (epwalsh, Nov 27, 2024)
3d23a2d  finalize config  (epwalsh, Nov 27, 2024)
d4ac791  Merge branch 'main' into epwalsh/olmo2-13B  (epwalsh, Nov 27, 2024)
32c520c  update  (epwalsh, Nov 27, 2024)
20eca35  Add max and abs to selective ops AC  (epwalsh, Nov 28, 2024)
07ae12f  Add methods to apply TP to various modules  (epwalsh, Nov 28, 2024)
ef5e695  ensure we don't f-up applying float8 to output again  (epwalsh, Nov 28, 2024)
eb710a1  only print config from local rank 0  (epwalsh, Dec 2, 2024)
4ebd618  fix  (epwalsh, Dec 2, 2024)
365f179  fix?  (epwalsh, Dec 2, 2024)
2aeed48  fix?  (epwalsh, Dec 2, 2024)
9f31745  fix  (epwalsh, Dec 2, 2024)
d37d45d  fix  (epwalsh, Dec 2, 2024)
9a6dc33  more fixes  (epwalsh, Dec 2, 2024)
ed6a02e  add test for tensor parallel model  (epwalsh, Dec 2, 2024)
17359c0  fixes  (epwalsh, Dec 2, 2024)
01634ff  expand test  (epwalsh, Dec 2, 2024)
527c443  fix GQA attention  (epwalsh, Dec 2, 2024)
7e8cbf2  updates  (epwalsh, Dec 2, 2024)
8ecc741  update train example to allow tensor parallel  (epwalsh, Dec 2, 2024)
b89236e  updates  (epwalsh, Dec 2, 2024)
9fec486  fix log  (epwalsh, Dec 2, 2024)
8da6f8e  more fixes  (epwalsh, Dec 2, 2024)
b411c2b  disable async tp by default  (epwalsh, Dec 2, 2024)
c01a374  log  (epwalsh, Dec 3, 2024)
b6a0981  fix  (epwalsh, Dec 3, 2024)
6348fb5  update size  (epwalsh, Dec 3, 2024)
de20020  try next nightly  (epwalsh, Dec 3, 2024)
2863fe9  comment out  (epwalsh, Dec 3, 2024)
5858777  try flash, no compile  (epwalsh, Dec 3, 2024)
739aa89  fix  (epwalsh, Dec 3, 2024)
3dfdedf  try compile again  (epwalsh, Dec 3, 2024)
355f61a  adjust hidden size, NCCL log level  (epwalsh, Dec 3, 2024)
6e581db  ok try newer nightly again  (epwalsh, Dec 3, 2024)
4e1a2c3  ensure mock batch is the same for each DP rank  (epwalsh, Dec 3, 2024)
81adb75  try something else  (epwalsh, Dec 3, 2024)
95c41ff  clean up  (epwalsh, Dec 3, 2024)
a0005cd  fix speed monitor for TP  (epwalsh, Dec 3, 2024)
bed869f  clean up  (epwalsh, Dec 3, 2024)
c1f7d44  fix lint  (epwalsh, Dec 3, 2024)
d552a3a  changelog  (epwalsh, Dec 3, 2024)
8ba7f58  try compile again  (epwalsh, Dec 3, 2024)
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added support for tensor parallelism. See the `TransformerConfig` class for usage.

## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27

### Added
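The changelog entry above points to `TransformerConfig` for usage. Below is a minimal sketch of enabling parallelism through the model config, based on the `src/examples/llama/train.py` changes in this PR; the `tp_config` keyword and `TransformerTensorParallelConfig` name are assumptions by analogy with the existing `TransformerDataParallelConfig`, so check the class docs for the exact API.

```python
from olmo_core.config import DType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.utils import get_default_device

model_config = TransformerConfig.llama2_271M(
    vocab_size=50304,  # stand-in for tokenizer_config.padded_vocab_size()
    compile=True,
    dp_config=TransformerDataParallelConfig(
        name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
    ),
    # Assumed keyword, not confirmed by this diff: a TP config analogous to the
    # DP config above, e.g. tp_config=TransformerTensorParallelConfig(degree=2),
)

# Build the world mesh from the model config, then pass it to build(),
# as the updated example training script does.
device = get_default_device()
world_mesh = model_config.build_mesh(device=device)
model = model_config.build(
    init_device="meta",
    device=device,
    max_seq_len=1024,  # the example script uses config.dataset.sequence_length
    mesh=world_mesh,
)
```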
17 changes: 12 additions & 5 deletions src/examples/llama/train.py
@@ -18,7 +18,6 @@
TokenizerConfig,
)
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.distributed.utils import init_hybrid_shard_mesh
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
from olmo_core.train import (
@@ -59,6 +58,8 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
model_config = TransformerConfig.llama2_271M(
vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128
compile=True,
fused_ops=False,
use_flash=False,
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
@@ -178,16 +179,22 @@ def main(run_name: str, overrides: List[str]):
# Set RNG states on all devices.
seed_all(config.init_seed)

device = get_default_device()

# Build the world mesh, if needed.
world_mesh = config.model.build_mesh(device=device)

# Build components.
model = config.model.build(
init_device="meta",
device=get_default_device(),
dp_mesh=init_hybrid_shard_mesh(num_replicas=2),
device=device,
max_seq_len=config.dataset.sequence_length,
mesh=world_mesh,
)
optim = config.optim.build(model)
dataset = config.dataset.build()
data_loader = config.data_loader.build(dataset)
trainer = config.trainer.build(model, optim, data_loader)
data_loader = config.data_loader.build(dataset, mesh=world_mesh)
trainer = config.trainer.build(model, optim, data_loader, mesh=world_mesh)

# Save config to W&B and each checkpoint dir.
config_dict = config.as_config_dict()
2 changes: 0 additions & 2 deletions src/examples/ngpt/train.py
@@ -18,7 +18,6 @@
TokenizerConfig,
)
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.distributed.utils import init_hybrid_shard_mesh
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamConfig, CosWithWarmup
from olmo_core.train import (
@@ -174,7 +173,6 @@ def main(run_name: str, overrides: List[str]):
model = config.model.build(
init_device="meta",
device=get_default_device(),
dp_mesh=init_hybrid_shard_mesh(num_replicas=2),
)
optim = config.optim.build(model)
dataset = config.dataset.build()
20 changes: 19 additions & 1 deletion src/olmo_core/data/data_loader.py
@@ -35,9 +35,11 @@
import torch
import torch.distributed as dist
import torch.utils.data
from torch.distributed import DeviceMesh

from ..aliases import PathOrStr
from ..config import Config
from ..distributed.parallel import get_dp_process_group
from ..distributed.utils import barrier, get_fs_local_rank, get_rank, get_world_size
from ..exceptions import OLMoConfigurationError
from ..utils import get_default_device, roundrobin, threaded_generator
@@ -430,9 +432,14 @@ def reshuffle(self, epoch: Optional[int] = None, in_memory: bool = False, **kwar
self.build_and_save_global_indices(in_memory=in_memory)

def get_mock_batch(self) -> Dict[str, Any]:
rng = torch.Generator()
rng.manual_seed(self.seed + self.dp_rank)
num_instances = self.rank_batch_size // self.dataset.max_sequence_length
input_ids = torch.randint(
0, self.dataset.vocab_size, (num_instances, self.dataset.max_sequence_length)
0,
self.dataset.vocab_size,
(num_instances, self.dataset.max_sequence_length),
generator=rng,
)
return {"input_ids": input_ids}

@@ -908,14 +915,25 @@ def build(
dataset: NumpyDatasetBase,
*,
collator: Optional[DataCollator] = None,
mesh: Optional[DeviceMesh] = None,
dp_process_group: Optional[dist.ProcessGroup] = None,
) -> NumpyDataLoaderBase:
"""
Construct the :class:`NumpyDataLoaderBase`.

:param dataset: The dataset.
:param mesh: An optional ``DeviceMesh`` that defines the data parallel dimensions. Ideally
you should create this mesh using :func:`~olmo_core.distributed.parallel.build_device_mesh()`
or equivalently :meth:`olmo_core.nn.transformer.TransformerConfig.build_mesh()`.
Alternatively you can pass the ``dp_process_group`` instead.
:param dp_process_group: The data parallel process group.
"""
if self.work_dir is not None and not dataset.work_dir_set:
dataset.work_dir = Path(self.work_dir)

if dp_process_group is None and mesh is not None:
dp_process_group = get_dp_process_group(mesh)

dataset.prepare()

data_loader = NumpyDataLoaderBase.wrap_numpy_dataset(
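The new `mesh` argument above lets the data loader derive its data parallel process group from the same device mesh used for the model. A short usage sketch follows, assuming an `ExperimentConfig` like the one built in `src/examples/llama/train.py` (the `config` object here stands in for the one built in that script).

```python
from olmo_core.distributed.parallel import get_dp_process_group
from olmo_core.utils import get_default_device

# `config` is an ExperimentConfig as built in src/examples/llama/train.py (assumed here).
world_mesh = config.model.build_mesh(device=get_default_device())
dataset = config.dataset.build()

# Preferred: pass the mesh and let the loader resolve the DP group itself.
data_loader = config.data_loader.build(dataset, mesh=world_mesh)

# Equivalent: resolve the DP process group explicitly and pass it instead.
data_loader = config.data_loader.build(
    dataset, dp_process_group=get_dp_process_group(world_mesh)
)
```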
162 changes: 160 additions & 2 deletions src/olmo_core/distributed/parallel/__init__.py
@@ -1,3 +1,161 @@
from .data_parallel import DataParallelConfig, DataParallelType
import logging
from typing import List, Optional

__all__ = ["DataParallelType", "DataParallelConfig"]
from torch.distributed import DeviceMesh, ProcessGroup, init_device_mesh

from olmo_core.config import StrEnum
from olmo_core.distributed.utils import get_num_nodes, get_world_size
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.utils import get_default_device

from .data_parallel import DataParallelConfig, DataParallelType, DPMeshDimName
from .tensor_parallel import TensorParallelConfig

__all__ = [
"build_device_mesh",
"MeshDimName",
"get_dp_mesh",
"get_tp_mesh",
"get_dp_process_group",
"DataParallelType",
"DataParallelConfig",
"DPMeshDimName",
"TensorParallelConfig",
]

log = logging.getLogger(__name__)


class MeshDimName(StrEnum):
"""
``DeviceMesh`` dimensions names for different forms of parallelism.
"""

dp = "dp"
"""
Data parallel (DP).
"""

dp_replicate = DPMeshDimName.replicate
"""
The DP dimension over which the model is replicated.
"""

dp_shard = DPMeshDimName.shard
"""
The DP dimension over which the model is sharded.
"""

tp = "tp"
"""
Tensor parallel (TP).
"""


def build_device_mesh(
*,
dp: Optional[DataParallelConfig] = None,
tp: Optional[TensorParallelConfig] = None,
device_type: Optional[str] = None,
) -> Optional[DeviceMesh]:
"""
Build a ``DeviceMesh`` suitable for the given parallel strategies.
The resulting dimension names will be defined in :class:`MeshDimName`.
"""
device_type = device_type or get_default_device().type

if tp is None and dp is None:
return None
elif tp is None:
assert dp is not None
return dp.build_device_mesh(device_type=device_type)
else:
assert dp is not None
assert tp is not None

if get_world_size() % tp.degree != 0:
raise OLMoConfigurationError(
f"World size {get_world_size()} must be divisible by TP degree ({tp.degree})"
)

dp_world_size = get_world_size() // tp.degree

dims: List[int] = []
names: List[str] = []

if dp.name == DataParallelType.hsdp:
num_replicas = dp.num_replicas or get_num_nodes()
if dp_world_size % num_replicas != 0:
raise OLMoConfigurationError(
f"HSDP requires DP world size ({dp_world_size}) to be divisible by 'num_replicas' ({num_replicas})"
)
dims.append(num_replicas)
dims.append(dp_world_size // num_replicas)
names.append(MeshDimName.dp_replicate)
names.append(MeshDimName.dp_shard)
else:
dims.append(dp_world_size)
names.append(MeshDimName.dp)

dims.append(tp.degree)
names.append(MeshDimName.tp)

log.info(f"Building {len(dims)}-D device mesh with {names}, {dims}...")

return init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names))


def get_dp_mesh(device_mesh: Optional[DeviceMesh] = None) -> Optional[DeviceMesh]:
"""
Get the data parallel sub-mesh associated with a ``DeviceMesh`` that was potentially
created from :func:`build_device_mesh()`.
"""
if device_mesh is None:
return None

if device_mesh.mesh_dim_names is None:
raise RuntimeError("could not determine data parallel sub-mesh without dimension names")

if MeshDimName.dp in device_mesh.mesh_dim_names:
return device_mesh[MeshDimName.dp]
elif (
MeshDimName.dp_replicate in device_mesh.mesh_dim_names
and MeshDimName.dp_shard in device_mesh.mesh_dim_names
):
return device_mesh[MeshDimName.dp_replicate, MeshDimName.dp_shard]
else:
raise RuntimeError(
f"could not determine data parallel sub-mesh from mesh with dimensions {device_mesh.mesh_dim_names}"
)


def get_dp_process_group(device_mesh: Optional[DeviceMesh] = None) -> Optional[ProcessGroup]:
"""
Get the data parallel process group associated with a ``DeviceMesh`` that was potentially
created from :func:`build_device_mesh()`.
"""
dp_mesh = get_dp_mesh(device_mesh)
if dp_mesh is None:
return None
else:
if len(dp_mesh.shape) > 1:
return dp_mesh._flatten(mesh_dim_name=MeshDimName.dp).get_group()
else:
return dp_mesh.get_group()


def get_tp_mesh(device_mesh: Optional[DeviceMesh] = None) -> Optional[DeviceMesh]:
"""
Get the tensor parallel sub-mesh associated with a ``DeviceMesh`` that was potentially
created from :func:`build_device_mesh()`.
"""
if device_mesh is None:
return None

if device_mesh.mesh_dim_names is None:
raise RuntimeError("could not determine tensor parallel sub-mesh without dimension names")

if MeshDimName.tp in device_mesh.mesh_dim_names:
return device_mesh[MeshDimName.tp]
else:
return None
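A brief usage sketch for the helpers added above. It assumes a distributed run (e.g. under `torchrun`) with 8 ranks, and that `TensorParallelConfig` exposes a `degree` field, which is how it is used inside `build_device_mesh()`.

```python
from olmo_core.distributed.parallel import (
    DataParallelConfig,
    DataParallelType,
    TensorParallelConfig,
    build_device_mesh,
    get_dp_mesh,
    get_dp_process_group,
    get_tp_mesh,
)

# e.g. 8 ranks total: TP degree 2 leaves a DP world size of 4.
mesh = build_device_mesh(
    dp=DataParallelConfig(name=DataParallelType.fsdp),
    tp=TensorParallelConfig(degree=2),  # `degree` assumed from its use above
    device_type="cuda",
)

dp_mesh = get_dp_mesh(mesh)  # "dp" sub-mesh ("dp_replicate" x "dp_shard" for HSDP)
tp_mesh = get_tp_mesh(mesh)  # "tp" sub-mesh, or None if TP wasn't configured
dp_group = get_dp_process_group(mesh)  # ProcessGroup to hand to the data loader
```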
46 changes: 46 additions & 0 deletions src/olmo_core/distributed/parallel/data_parallel.py
@@ -1,11 +1,35 @@
import logging
from dataclasses import dataclass
from typing import Optional

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from olmo_core.config import Config, DType, StrEnum
from olmo_core.distributed.utils import get_num_nodes, get_world_size
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.utils import get_default_device

log = logging.getLogger(__name__)


class DPMeshDimName(StrEnum):
"""
``DeviceMesh`` dimension names for data parallelism.
"""

replicate = "dp_replicate"
"""
The device mesh dimension over which the model is replicated.
"""
shard = "dp_shard"
"""
The device mesh dimension over which the model is sharded.
"""


class DataParallelType(StrEnum):
fsdp = "fsdp"
hsdp = "hsdp"
ddp = "ddp"


@@ -14,3 +38,25 @@ class DataParallelConfig(Config):
name: DataParallelType
param_dtype: Optional[DType] = None
reduce_dtype: DType = DType.float32
num_replicas: Optional[int] = None

def build_device_mesh(self, device_type: Optional[str] = None) -> Optional[DeviceMesh]:
"""
Build the optional device mesh needed for this config.
"""
if self.name == DataParallelType.hsdp:
num_replicas = self.num_replicas or get_num_nodes()
device_type = device_type or get_default_device().type
if get_world_size() % num_replicas != 0:
raise OLMoConfigurationError(
"HSDP requires world size to be divisible by 'num_replicas'"
)

log.info(f"Building device mesh for HSDP with {num_replicas} replicas...")
return init_device_mesh(
device_type,
(num_replicas, get_world_size() // num_replicas),
mesh_dim_names=(DPMeshDimName.replicate, DPMeshDimName.shard),
)
else:
return None
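To make the `num_replicas` behavior above concrete, here is a small sketch of building an HSDP mesh on its own, assuming 8 ranks across 2 nodes in an already-initialized distributed run.

```python
from olmo_core.config import DType
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType

dp = DataParallelConfig(
    name=DataParallelType.hsdp,
    param_dtype=DType.bfloat16,
    reduce_dtype=DType.float32,
    num_replicas=2,  # defaults to get_num_nodes() when left unset
)

# Returns a 2-D mesh with dims ("dp_replicate", "dp_shard") of shape
# (2, world_size // 2); non-HSDP configs return None from this method.
mesh = dp.build_device_mesh(device_type="cuda")
```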