Commit
Merge remote-tracking branch 'origin/main' into dev
dlwh committed Feb 13, 2024
2 parents ddba7fd + 0b8b6e9 commit 001ea75
Showing 11 changed files with 148 additions and 88 deletions.
12 changes: 2 additions & 10 deletions .github/workflows/run_pre_commit.yaml
@@ -9,6 +9,7 @@ jobs:
    strategy:
      matrix:
        python-version: ["3.10"]
        jax-version: ["0.4.14"]

    steps:
    - uses: actions/checkout@v3
@@ -20,16 +21,7 @@ jobs:
      run: |
        python -m pip install --upgrade pip
        pip install flake8 pytest pre-commit
        pip install --upgrade "jax[cpu]==0.4.11" "jaxlib==0.4.11"
        # install haliax from source b/c it's changing in parallel with this repo
        pip install git+https://github.com/stanford-crfm/haliax.git
        pip install .
    # - name: Lint with flake8
    #   run: |
    #     # stop the build if there are Python syntax errors or undefined names
    #     flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
    #     # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
    #     flake8 . --count --exit-zero --max-complexity=50 --max-line-length=127 --statistics
        pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
    - name: "Run Pre-commit"
      run: |
        pre-commit run --all-files --show-diff-on-failure
32 changes: 32 additions & 0 deletions config/gpt2_1536_sophiah.yaml
@@ -0,0 +1,32 @@
data:
  train_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://levanter-data/tokenized/openwebtext/"
  tokenizer: "gpt2"
model:
  type: gpt2
  hidden_dim: 1536
  num_heads: 24
  num_layers: 48
  seq_len: 1024
  gradient_checkpointing: true
  scale_attn_by_inverse_layer_idx: true
trainer:
  tracker:
    project: "levanter"
    tags: [ "openwebtext", "gpt2"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: 2
  per_device_eval_parallelism: 8
optimizer:
  type: sophia-h
  learning_rate: 2E-4
  weight_decay: 0.2
  min_lr_ratio: 0.1
  gamma: 0.01
  # sophia needs a minimum amount of warmup or it doesn't do well
  warmup: 2000
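
A quick way to sanity-check a config like this before launching a run is to load it with PyYAML and assert on the fields that matter; a minimal sketch (not part of the commit; the path and field names are taken from the YAML above):

    import yaml

    with open("config/gpt2_1536_sophiah.yaml") as f:
        cfg = yaml.safe_load(f)

    assert cfg["optimizer"]["type"] == "sophia-h"
    assert cfg["optimizer"]["warmup"] >= 2000  # per the comment above, sophia wants real warmup
    print(cfg["model"]["hidden_dim"], cfg["trainer"]["mp"])  # 1536 p=f32,c=bfloat16
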
68 changes: 34 additions & 34 deletions examples/alpaca-lora/alpaca_lora.py
@@ -114,40 +114,40 @@ def loraize_hf_model(model):
        }
    )

    logger.info(f"Total parameter count: {all_param_count}")
    logger.info(f"Trainable parameter count: {just_lora_params}")
    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")

    # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
    # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
    # datasets. We use replicated here since the dataset is small.
    loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
    loader = non_caching_cycle(loader)

    if state.step != 0:
        logger.info(f"Resuming training from step {state.step}")
        for i in range(state.step):
            next(loader)  # type: ignore

    # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
    if config.hf_save_path is not None:
        full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
        trainer.add_hook(
            save_peft_checkpoint_callback(
                full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
            ),
            every=config.hf_save_steps,
        )

    # Save merged HF checkpoints if requested
    if config.merged_hf_save_path is not None:
        full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
        trainer.add_hook(
            save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
            every=config.hf_save_steps,
        )

    trainer.train(state, loader)
    logger.info(f"Total parameter count: {all_param_count}")
    logger.info(f"Trainable parameter count: {just_lora_params}")
    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")

    # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
    # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
    # datasets. We use replicated here since the dataset is small.
    loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
    loader = non_caching_cycle(loader)

    if state.step != 0:
        logger.info(f"Resuming training from step {state.step}")
        for i in range(state.step):
            next(loader)  # type: ignore

    # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
    if config.hf_save_path is not None:
        full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
        trainer.add_hook(
            save_peft_checkpoint_callback(
                full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
            ),
            every=config.hf_save_steps,
        )

    # Save merged HF checkpoints if requested
    if config.merged_hf_save_path is not None:
        full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
        trainer.add_hook(
            save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
            every=config.hf_save_steps,
        )

    trainer.train(state, loader)


if __name__ == "__main__":
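
As the comment in this hunk notes, Levanter has replicated and sharded data loaders. A sketch of the choice, reusing the names from the surrounding code (illustrative only, not a new API):

    # small dataset: every process sees the full data, which permits single-pass training
    loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)

    # large dataset: each device loads only its shard (the pattern lora_lm.py uses below)
    # loader = trainer.sharded_loader(train_dataset, trainer.TrainBatch)
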
2 changes: 1 addition & 1 deletion src/levanter/__init__.py
@@ -8,5 +8,5 @@
import levanter.tracker as tracker
import levanter.trainer as trainer
import levanter.visualization as visualization
from levanter.tracker import current_tracker, get_tracker
from levanter.tracker import current_tracker
from levanter.trainer import initialize
5 changes: 3 additions & 2 deletions src/levanter/data/sharded_dataset.py
@@ -147,9 +147,10 @@ class WrappedHFDataset(ShardedDataset[dict]):
    kwargs are passed to load_dataset
    """

    def __init__(self, id, *, split, **kwargs):
    def __init__(self, id, *, split, streaming: bool = True, **kwargs):
        self.id = id
        self.split = split
        self.streaming = streaming
        self.kwargs = kwargs
        self._shard_names = self._compute_shard_names()

@@ -184,7 +185,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
    def _load_dataset(self):
        # obnoxiously, the dataset loading stuff doesn't work with ray because of multiprocessing
        # so we have to do this hacky thing where we load the dataset in the worker
        return datasets.load_dataset(self.id, split=self.split, **self.kwargs)
        return datasets.load_dataset(self.id, split=self.split, streaming=self.streaming, **self.kwargs)


class TextUrlDataset(ShardedDataset[str]):
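
The new streaming flag is forwarded straight to datasets.load_dataset. For reference, this is what the flag means in HF datasets (a standalone sketch; the dataset id is only an example):

    import datasets

    # streaming=True returns an IterableDataset and avoids downloading the whole corpus up front
    ds = datasets.load_dataset("openwebtext", split="train", streaming=True)
    print(next(iter(ds))["text"][:80])
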
2 changes: 2 additions & 0 deletions src/levanter/main/lora_lm.py
@@ -92,6 +92,7 @@ def loraize_hf_model(model):
    state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter)

    all_param_count = parameter_count(state.model)
    # TODO: remove this once we put this in trainer itself
    just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter))

    levanter.tracker.log_summary(
@@ -118,6 +119,7 @@ def loraize_hf_model(model):
    train_loader = trainer.sharded_loader(train_dataset, Batch)

    # boilerplate hooks and such
    trainer.add_eval_hook(eval_dataset)
    trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
    if config.peft_save_path is not None:
        full_save_path = os.path.join(config.peft_save_path, trainer.run_id)
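
Both added lines ride on the trainer's hook mechanism: add_eval_hook wires up periodic evaluation on a dataset, while add_hook(..., every=N) registers an arbitrary per-step callback. A hypothetical custom hook, to illustrate the pattern (the StepInfo attributes are assumptions based on how info is used elsewhere in this diff):

    def log_loss_to_stdout(step_info):
        # assumed fields: StepInfo carries the step number and the loss of the step just taken
        print(f"step {step_info.step}: loss={step_info.loss:.4f}")

    trainer.add_hook(log_loss_to_stdout, every=100)
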
6 changes: 2 additions & 4 deletions src/levanter/optim/sophia.py
@@ -11,7 +11,7 @@
from jax.random import PRNGKey
from jaxtyping import PRNGKeyArray

# import levanter.tracker
import levanter.tracker
from levanter.optim.config import HessianOptConfig, OptimizerConfig
from levanter.optim.util import hvp, tree_gaussian_like
from levanter.utils.jax_utils import parameter_count, tree_filter_like
@@ -348,9 +348,7 @@ def update_fn(updates, state, params=None, *, obj_fn, **kwargs):
        updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates)
        stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates)

        # this doesn't work well on CPU, so skip if cpu
        # if jax.lib.xla_bridge.get_backend().platform != "cpu":
        #     levanter.tracker.jit_log_metrics(stats, step=state.count)
        levanter.tracker.jit_log_metrics(stats, step=state.count)

        if mu_dtype is not None:
            mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu)
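
For context, the code around this change clips the preconditioned update elementwise and tracks how many parameters were left unclipped; with this commit that statistic is always logged via levanter.tracker.jit_log_metrics rather than being commented out for CPU. A self-contained sketch of the clip-and-count computation (illustrative, not levanter's exact code):

    import jax
    import jax.numpy as jnp

    def clip_updates(updates, clip_threshold):
        leaves = jax.tree_util.tree_leaves(updates)
        unclipped = sum(jnp.sum(jnp.abs(u) < clip_threshold) for u in leaves)
        total = sum(u.size for u in leaves)
        clipped = jax.tree_util.tree_map(
            lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates
        )
        return clipped, unclipped / total  # the fraction logged as optim/unclipped_fraction
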
60 changes: 33 additions & 27 deletions src/levanter/trainer.py
@@ -259,29 +259,30 @@ def EvalBatch(self):
        return self.config.EvalBatch

    def __enter__(self):
        if len(self._cmanagers) > 0:
            raise RuntimeError("Trainer is already entered")

        this_managers = [
        self._cmanagers = [
            levanter.current_tracker(self.tracker),
            self.device_mesh,
            hax.axis_mapping(self.parameter_axis_mapping),
        ]
        self._cmanagers.append(this_managers)

        for cmanager in this_managers:
        for cmanager in self._cmanagers:
            cmanager.__enter__()

        return self

    def __exit__(self, *args):
        assert len(self._cmanagers) > 0, "Trainer.__exit__ called without corresponding Trainer.__enter__"
        cur_managers = self._cmanagers.pop()
        problems = []
        for cmanager in reversed(cur_managers):
        for cmanager in reversed(self._cmanagers):
            try:
                cmanager.__exit__(*args)
            except Exception as e:
                problems.append(e)

        self._cmanagers = []

        if len(problems) > 0:
            raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0]

@@ -395,23 +396,23 @@ def training_steps(
        Generator that yields training steps and runs hooks.
        """
        iter_data = iter(train_loader)
        with levanter.current_tracker(self.tracker):
            while state.step < self.num_train_steps:
                with capture_time() as loading_time:
                    example = next(iter_data)

                levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step)
        while state.step < self.num_train_steps:
            with capture_time() as loading_time:
                example = next(iter_data)

                info = self.train_step(state, example)
            levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step)

                if run_hooks:
                    with capture_time() as hook_time:
                        self.run_hooks(info)
            info = self.train_step(state, example)
            state = info.state

                    levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step)
            if run_hooks:
                with capture_time() as hook_time:
                    self.run_hooks(info)

                state = info.state
                yield info
                levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step)

            yield info

    def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]:
        """
Expand Down Expand Up @@ -488,21 +489,21 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal
        key, new_key = jax.random.split(state.training_key)
        model = inference_mode(state.model, False)

        loss, grads = self._compute_gradients_microbatched(model, batch, **batch_kwargs, key=key)
        loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key)

        new_state = self._take_train_step(state, model, grads, *batch, **batch_kwargs, key=key)
        new_state = dataclasses.replace(new_state, training_key=new_key)

        return loss, new_state

    def _compute_gradients_microbatched(self, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]:
        grad_fn = eqx.filter_value_and_grad(self.loss_fn, has_aux=False)
    def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]:
        grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
        grad_fn = microbatched(
            grad_fn,
            self.TrainBatch,
            self.config.per_device_parallelism,
            self.parameter_axis_mapping,
            self.config.microbatch_size,
            self.parameter_axis_mapping,
            self.compute_axis_mapping,
        )
        return grad_fn(model, *batch, **batch_kwargs)
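
microbatched (with the new argument order: batch axis, microbatch size, then parameter and compute axis mappings) evaluates the gradient over slices of the batch and accumulates, so peak activation memory stays bounded. A conceptual stand-in for that idea, not levanter's implementation:

    import jax
    import jax.numpy as jnp

    def microbatched_grad(loss_fn, num_microbatches):
        grad_fn = jax.grad(loss_fn)

        def wrapped(params, batch):
            # assumes the batch size divides evenly into num_microbatches
            micro = jnp.stack(jnp.split(batch, num_microbatches))
            grads = jax.lax.map(lambda mb: grad_fn(params, mb), micro)
            # average the per-microbatch gradients leaf by leaf
            return jax.tree_util.tree_map(lambda g: g.mean(axis=0), grads)

        return wrapped
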

@@ -512,13 +513,14 @@ def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S:
"""
# only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us
with hax.axis_mapping(self.parameter_axis_mapping):
opt_state = state.opt_state
train_grads = _partition_trainable_params(grads, state.is_trainable)[0]
trainable_model = _partition_trainable_params(model, state.is_trainable)[0]
updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model)

partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs)

updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn)
updates, opt_state = self.optimizer.update(
train_grads, opt_state, params=trainable_model, obj_fn=partial_fn
)
model = eqx.apply_updates(model, updates)

return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state)
@@ -787,7 +789,7 @@ class AllConfig(Protocol):

def initialize(config: TrainerConfig | AllConfig):
"""Initializes jax, logging, setting the run name/id in the process. Also initializes tracking and saves config
as hyperparameters and as an artifact"""
as hyperparameters and an artifact"""
if isinstance(config, TrainerConfig):
trainer_config = config
else:
@@ -797,6 +799,10 @@ def initialize(config: TrainerConfig | AllConfig):
    levanter.tracker.log_configuration(config)


def _params_only(t):
    return eqx.filter(t, is_inexact_arrayish)


def _partition_trainable_params(model, filter):
"""
Partitions the model into trainable and non-trainable parameters. This is used internally
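
_partition_trainable_params splits the model pytree so that only trainable leaves reach the optimizer, which is exactly how _take_train_step uses it above. A minimal sketch of the same pattern with equinox and an optax-style optimizer (the function and argument names here are illustrative):

    import equinox as eqx

    def take_step(model, grads, opt_state, optimizer, is_trainable):
        train_grads = eqx.filter(grads, is_trainable)
        trainable_model = eqx.filter(model, is_trainable)
        updates, opt_state = optimizer.update(train_grads, opt_state, params=trainable_model)
        model = eqx.apply_updates(model, updates)  # None-updates leave frozen leaves untouched
        return model, opt_state
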
3 changes: 2 additions & 1 deletion src/levanter/utils/hf_utils.py
@@ -18,7 +18,8 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int:
        else:
            # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood.
            # we reserve a couple of cores just so Ray has somewhere to run the coordinator.
            return min(max(1, logical_cpu_core_count() - 2), 32)
            # Empirically it doesn't usually exceed 16-20, and it's useful to have some slack
            return min(max(1, logical_cpu_core_count() - 2), 12)
    else:
        return 1

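
The change tightens the core cap from 32 to 12 while keeping two cores reserved for Ray's coordinator. The clamp in isolation (a standalone sketch of just this arithmetic):

    def tokenizer_cpu_budget(total_cores: int) -> int:
        # reserve 2 cores for Ray's coordinator, then cap at 12
        return min(max(1, total_cores - 2), 12)

    assert tokenizer_cpu_budget(64) == 12
    assert tokenizer_cpu_budget(2) == 1
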
21 changes: 12 additions & 9 deletions tests/test_doremi.py
@@ -6,6 +6,7 @@

import haliax as hax

import levanter.tracker
from levanter.callbacks import eval_loss_loop
from levanter.data.dataset import ShardableDataset
from levanter.data.mixture import MixtureDataset
@@ -136,15 +137,17 @@ def init_model():
    assert l3_ref < l1_ref < l2_ref

    from levanter.doremi import estimate_mixture_weights

    w = estimate_mixture_weights(
        initial_proxy=init_model(),
        ref=ref_model,
        data_sources=datasets,
        trainer_config=tiny_trainer_config,
        key=next(keys),
        loss_fn=compute_loss_fn,
    )
    from levanter.tracker import NoopTracker

    with levanter.tracker.current_tracker(NoopTracker()):
        w = estimate_mixture_weights(
            initial_proxy=init_model(),
            ref=ref_model,
            data_sources=datasets,
            trainer_config=tiny_trainer_config,
            key=next(keys),
            loss_fn=compute_loss_fn,
        )

    w1 = w["d1"]
    w2 = w["d2"]