Add a callback for downstream evals, update Docker builds #73

Merged · 27 commits · Oct 30, 2024
7 changes: 4 additions & 3 deletions .github/workflows/main.yml
@@ -118,7 +118,7 @@ jobs:
           src/test/
 
       - name: Test checkpoint (GPU)
-        image: olmo-core
+        image: olmo-core-nightly
         gpus: 2
         run: |
           pytest -v --color=yes --durations=3 -m gpu \
@@ -180,10 +180,11 @@ jobs:
           gpuCount: ${{ matrix.task.gpus }}
           constraints:
             cluster:
-              - ai2/allennlp-cirrascale
-              - ai2/allennlp-elanding-a100-40g
+              # - ai2/allennlp-cirrascale
+              # - ai2/allennlp-elanding-a100-40g
               - ai2/pluto-cirrascale
               - ai2/jupiter-cirrascale-2
+              # - ai2/saturn-cirrascale
           envVars:
             - name: CUBLAS_WORKSPACE_CONFIG
               value: ":16:8"
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## Unreleased
 
+### Added
+
+- Added `DownstreamEvaluatorCallbackConfig` class for running in-loop downstream eval via [OLMo-in-loop-evals](https://github.com/allenai/OLMo-in-loop-evals).
+
+### Removed
+
+- Removed `flash-attn` from the Beaker images since `flash-attn` currently can't be built for torch 2.5.1. We are waiting on updates from the `flash-attn` maintainers. See https://github.com/Dao-AILab/flash-attention/issues/1302.
+
 ## [v1.5.0](https://github.com/allenai/OLMo-core/releases/tag/v1.5.0) - 2024-10-23
 
 ### Added
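A minimal sketch of how the new callback config is wired up, mirroring the src/examples/train.py change further down in this diff (the dolma2 tokenizer preset here is an assumption; any `TokenizerConfig` matching your training run should work):

    from olmo_core.data.tokenizer import TokenizerConfig
    from olmo_core.train.callbacks import DownstreamEvaluatorCallbackConfig

    downstream_eval = DownstreamEvaluatorCallbackConfig(
        tasks=["hellaswag"],                 # a task name from OLMo-in-loop-evals
        tokenizer=TokenizerConfig.dolma2(),  # assumed preset; match your training tokenizer
        eval_interval=250,                   # run the eval every 250 training steps
    )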
12 changes: 8 additions & 4 deletions Makefile
@@ -1,11 +1,13 @@
-BASE_IMAGE = ghcr.io/allenai/pytorch:2.4.1-cuda12.1-python3.11
+# NOTE: make sure CUDA versions match across these variables
+BASE_IMAGE = ghcr.io/allenai/pytorch:2.5.1-cuda12.1-python3.11-v2024.10.29
+CUDA_TOOLKIT_VERSION = 12.1.0
+TORCH_CUDA_VERSION = 121
 
 # NOTE: when upgrading the nightly version you also need to upgrade the torch version specification
 # in 'pyproject.toml' to include that nightly version.
-NIGHTLY_VERSION = "2.6.0.dev20241009+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121"
-TORCHAO_VERSION = "torchao==0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121"
+NIGHTLY_VERSION = "2.6.0.dev20241009+cu121"
+TORCHAO_VERSION = "torchao==0.5.0"
 MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://[email protected]/epwalsh/megablocks.git@epwalsh/deps"
-CUDA_TOOLKIT_VERSION = 12.1.0
 
 VERSION = $(shell python src/olmo_core/version.py)
 VERSION_SHORT = $(shell python src/olmo_core/version.py short)
@@ -49,6 +51,7 @@ stable-image :
         --build-arg BUILDKIT_INLINE_CACHE=1 \
         --build-arg BASE=$(BASE_IMAGE) \
         --build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \
+        --build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \
         --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
         --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
         --target stable \
@@ -62,6 +65,7 @@ nightly-image :
         --build-arg BUILDKIT_INLINE_CACHE=1 \
         --build-arg BASE=$(BASE_IMAGE) \
         --build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \
+        --build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \
         --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
         --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
         --build-arg NIGHTLY_VERSION=$(NIGHTLY_VERSION) \
1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
     "omegaconf",
     "safetensors",
     "importlib_resources",
+    "ai2-olmo-eval==0.2.0",
 ]
 
 [project.urls]
29 changes: 25 additions & 4 deletions src/Dockerfile
@@ -13,22 +13,39 @@ WORKDIR /app/build
 ARG CUDA_TOOLKIT_VERSION
 RUN conda install -y -c nvidia cuda-toolkit==${CUDA_TOOLKIT_VERSION}
 
+ARG TORCH_CUDA_VERSION
+
 # Build megablocks and grouped-gemm.
 ENV TORCH_CUDA_ARCH_LIST="8.0 9.0"
 ENV GROUPED_GEMM_CUTLASS=1
 ARG MEGABLOCKS_VERSION
-RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" \
-    && rm -rf torch-*.whl numpy-*.whl triton-*.whl
+RUN pip wheel --no-build-isolation --no-cache-dir \
+    --extra-index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} \
+    "${MEGABLOCKS_VERSION}"
+
+# Flash-attn from pre-built wheel (can't get this to work at the moment)
+#RUN wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl
+
+# Only keep the target wheels and dependencies with CUDA extensions.
+RUN echo "Built wheels:" \
+    && ls -lh . \
+    && ls -1 | grep -Ev 'megablocks|grouped_gemm|stanford_stk|flash_attn' | xargs rm \
+    && echo "Final wheels:" \
+    && ls -lh .
 
 #########################################################################
 # Stable image
 #########################################################################
 
 FROM ${BASE} as stable
 
+ARG TORCH_CUDA_VERSION
+
 # Install torchao.
 ARG TORCHAO_VERSION
-RUN pip install --no-cache-dir ${TORCHAO_VERSION}
+RUN pip install --no-cache-dir \
+    --extra-index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} \
+    ${TORCHAO_VERSION}
 
 # Copy and install wheels from build image.
 COPY --from=build /app/build /app/build
@@ -50,5 +67,9 @@ WORKDIR /app/olmo-core
 
 FROM stable as nightly
 
+ARG TORCH_CUDA_VERSION
+
 ARG NIGHTLY_VERSION
-RUN pip install --no-cache-dir --pre torch==${NIGHTLY_VERSION}
+RUN pip install --no-cache-dir --pre \
+    --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} \
+    torch==${NIGHTLY_VERSION}
11 changes: 10 additions & 1 deletion src/examples/train.py
@@ -31,6 +31,7 @@
     CheckpointerCallback,
     CometCallback,
     ConfigSaverCallback,
+    DownstreamEvaluatorCallbackConfig,
     GPUMemoryMonitorCallback,
     GradClipperCallback,
     LMEvaluatorCallbackConfig,
@@ -133,7 +134,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
         .with_callback("config_saver", ConfigSaverCallback())
         .with_callback("profiler", ProfilerCallback(enabled=False))
         .with_callback(
-            "evaluator",
+            "lm_evaluator",
             LMEvaluatorCallbackConfig(
                 eval_dataset=NumpyDatasetConfig(
                     paths=["/net/nfs/allennlp/llm-data/c4/en/c4-validation.00000-00008.npy"],
@@ -147,6 +148,14 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
                 eval_duration=Duration.steps(10),
             ),
         )
+        .with_callback(
+            "downstream_evaluator",
+            DownstreamEvaluatorCallbackConfig(
+                tasks=["hellaswag"],
+                tokenizer=tokenizer_config,
+                eval_interval=250,
+            ),
+        )
     )
 
     return ExperimentConfig(
6 changes: 4 additions & 2 deletions src/olmo_core/data/mixes/__init__.py
@@ -1,4 +1,5 @@
 import os
+from abc import abstractmethod
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Generator, List, Tuple
@@ -15,7 +16,8 @@ class DataMixBase(StrEnum):
     Base class for enumeration of data mixes.
     """
 
-    def build(self, base_dir: str, tokenizer: TokenizerName) -> Tuple[List[str], List[str]]:
+    @abstractmethod
+    def build(self, base_dir: str, tokenizer: str) -> Tuple[List[str], List[str]]:
         """
         Construct the data mix.
 
@@ -37,7 +39,7 @@ class DataMix(DataMixBase):
     dolma17 = "dolma17"
     v3_small_ppl_validation = "v3-small-ppl-validation"
 
-    def build(self, base_dir: str, tokenizer: TokenizerName) -> Tuple[List[str], List[str]]:
+    def build(self, base_dir: str, tokenizer: str) -> Tuple[List[str], List[str]]:
         if not base_dir.endswith("/"):
             base_dir = base_dir + "/"
 
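A usage sketch for the loosened signature, which now accepts any tokenizer identifier string rather than only a `TokenizerName` member. The base directory below is hypothetical, and reading the two returned lists as data paths plus matching labels is an assumption not shown in this diff:

    from olmo_core.data.mixes import DataMix

    # DataMix members are StrEnum values, so a mix is selected by name.
    paths, labels = DataMix.dolma17.build(
        "s3://my-bucket/preprocessed/",  # hypothetical base_dir
        "allenai/dolma2-tokenizer",      # plain string now, not TokenizerName
    )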
48 changes: 46 additions & 2 deletions src/olmo_core/data/tokenizer.py
@@ -3,15 +3,31 @@
 
 from ..config import Config, StrEnum
 
+__all__ = [
+    "TokenizerConfig",
+    "TokenizerName",
+]
+
 
 class TokenizerName(StrEnum):
     """
-    An enumeration of supported tokenizer names.
+    An enumeration of tokenizer identifiers commonly used by OLMo researchers.
     """
 
     dolma2 = "allenai/dolma2-tokenizer"
+    """
+    The dolma2 tokenizer.
+    """
+
     gpt_neox_olmo_dolma_v1_5 = "allenai/gpt-neox-olmo-dolma-v1_5"
+    """
+    A modified GPT NeoX tokenizer.
+    """
+
     gpt2 = "gpt2"
+    """
+    The base GPT2 tokenizer.
+    """
 
 
 @dataclass
@@ -21,10 +37,29 @@ class TokenizerConfig(Config):
     """
 
     vocab_size: int
+    """
+    The vocab size.
+    """
+
     eos_token_id: int
+    """
+    The end-of-sentence token ID.
+    """
+
     pad_token_id: int
+    """
+    The padding token ID.
+    """
+
     bos_token_id: Optional[int] = None
-    identifier: Optional[TokenizerName] = None
+    """
+    The begin-of-sentence token ID.
+    """
+
+    identifier: Optional[str] = None
+    """
+    The identifier of the tokenizer. Could be a path or HuggingFace identifier.
+    """
 
     def padded_vocab_size(self, pad_multiple: int = 128) -> int:
         """
@@ -35,6 +70,9 @@ def padded_vocab_size(self, pad_multiple: int = 128) -> int:
 
     @classmethod
     def dolma2(cls) -> "TokenizerConfig":
+        """
+        Get a :data:`~TokenizerName.dolma2` tokenizer config.
+        """
         return cls(
             vocab_size=100278,
             eos_token_id=100257,
@@ -44,6 +82,9 @@ def dolma2(cls) -> "TokenizerConfig":
 
     @classmethod
     def gpt_neox_olmo_dolma_v1_5(cls) -> "TokenizerConfig":
+        """
+        Get a :data:`~TokenizerName.gpt_neox_olmo_dolma_v1_5` tokenizer config.
+        """
         return cls(
             vocab_size=50280,
             eos_token_id=50279,
@@ -53,6 +94,9 @@ def gpt_neox_olmo_dolma_v1_5(cls) -> "TokenizerConfig":
 
     @classmethod
     def gpt2(cls) -> "TokenizerConfig":
+        """
+        Get a :data:`~TokenizerName.gpt2` tokenizer config.
+        """
        return cls(
            vocab_size=50280,
            eos_token_id=50256,
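A quick sketch of the presets documented above. The body of `padded_vocab_size()` is collapsed in this diff, so the rounding behavior here is an assumption (round up to the nearest multiple of `pad_multiple`):

    from olmo_core.data.tokenizer import TokenizerConfig

    config = TokenizerConfig.dolma2()
    print(config.vocab_size)           # 100278
    # Assuming round-up-to-multiple semantics: 784 * 128 = 100352 is the
    # smallest multiple of 128 that is >= 100278.
    print(config.padded_vocab_size())  # 100352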
19 changes: 18 additions & 1 deletion src/olmo_core/distributed/utils.py
@@ -13,7 +13,7 @@
 
 from ..config import StrEnum
 from ..exceptions import OLMoConfigurationError, OLMoEnvironmentError
-from ..utils import get_default_device, set_env_var
+from ..utils import get_default_device, move_to_device, set_env_var
 
 OLMO_SHARED_FS_ENV_VAR = "OLMO_SHARED_FS"
 OLMO_FS_LOCAL_RANK_ENV_VAR = "FS_LOCAL_RANK"
@@ -270,6 +270,23 @@ def scatter_object(obj: T, src: int = 0, group: Optional[dist.ProcessGroup] = None
     return output_list[0]
 
 
+def all_gather(
+    tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None
+) -> List[torch.Tensor]:
+    """
+    All-gather tensors from the whole group into a list.
+    """
+    if not is_distributed():
+        return [tensor]
+
+    shapes = all_gather_object(tensor.shape, group=group)
+    output_list = [
+        move_to_device(torch.zeros(shape, dtype=tensor.dtype), tensor.device) for shape in shapes
+    ]
+    dist.all_gather(output_list, tensor, group=group)
+    return output_list
+
+
 def all_gather_object(obj: T, group: Optional[dist.ProcessGroup] = None) -> List[T]:
     """
     All-gather an object using pickle to all ranks in a process group.
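A usage sketch for the new helper, assuming the process group is already initialized (e.g. via torchrun). Because shapes are exchanged first through `all_gather_object`, ranks may contribute tensors of different shapes, which plain `dist.all_gather` does not allow:

    import torch
    import torch.distributed as dist

    from olmo_core.distributed.utils import all_gather

    rank = dist.get_rank()
    # Ragged input: each rank holds a tensor of a different length.
    local = torch.full((rank + 1,), float(rank))
    gathered = all_gather(local)  # list with one tensor per rank
    total = sum(t.sum().item() for t in gathered)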
7 changes: 7 additions & 0 deletions src/olmo_core/eval/metrics.py
@@ -67,6 +67,10 @@ def __init__(
     def update(
         self, value: Union[float, torch.Tensor], weight: Union[float, torch.Tensor] = 1.0
     ) -> None:
+        """
+        :param value: The latest value to update the metric with. Could be a tensor of values.
+        :param weight: The corresponding weight(s) for the value. Should be the same shape as ``value``.
+        """
         value = self.as_tensor(value)
         weight = torch.broadcast_to(self.as_tensor(weight), value.shape)
         if value.numel() == 0:
@@ -75,6 +79,9 @@ def update(
         self.weight += weight.sum()
 
     def compute(self) -> torch.Tensor:
+        """
+        Computes the mean over the values and weights given.
+        """
         weighted_sum = all_reduce_value(
             self.weighted_sum, device=self.device, group=self.process_group
         )
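The docstrings above describe a weighted running mean. A standalone illustration of those semantics in plain torch (the metric class itself is collapsed in this diff, so this deliberately avoids the library API):

    import torch

    values = torch.tensor([2.0, 4.0, 6.0])
    weights = torch.tensor([1.0, 1.0, 2.0])

    # update() accumulates a weighted sum and a total weight...
    weighted_sum = (values * weights).sum()  # 2 + 4 + 12 = 18
    weight = weights.sum()                   # 4
    # ...and compute() returns their ratio (after an all-reduce across ranks).
    mean = weighted_sum / weight             # tensor(4.5)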
7 changes: 6 additions & 1 deletion src/olmo_core/train/callbacks/__init__.py
@@ -3,7 +3,11 @@
 from .comet import CometCallback, CometNotificationSetting
 from .config_saver import ConfigSaverCallback
 from .console_logger import ConsoleLoggerCallback
-from .evaluator_callback import EvaluatorCallback, LMEvaluatorCallbackConfig
+from .evaluator_callback import (
+    DownstreamEvaluatorCallbackConfig,
+    EvaluatorCallback,
+    LMEvaluatorCallbackConfig,
+)
 from .float8_handler import Float8HandlerCallback
 from .garbage_collector import GarbageCollectorCallback
 from .gpu_memory_monitor import GPUMemoryMonitorCallback
@@ -27,6 +31,7 @@
     "EvaluatorCallback",
     "Float8HandlerCallback",
     "LMEvaluatorCallbackConfig",
+    "DownstreamEvaluatorCallbackConfig",
     "MoEHandlerCallback",
     "GarbageCollectorCallback",
     "GPUMemoryMonitorCallback",