diff --git a/.github/workflows/run_pre_commit.yaml b/.github/workflows/run_pre_commit.yaml
index eb7a4b214..e80facaeb 100644
--- a/.github/workflows/run_pre_commit.yaml
+++ b/.github/workflows/run_pre_commit.yaml
@@ -9,6 +9,7 @@ jobs:
     strategy:
       matrix:
         python-version: ["3.10"]
+        jax-version: ["0.4.14"]
 
     steps:
     - uses: actions/checkout@v3
@@ -20,16 +21,7 @@ jobs:
       run: |
        python -m pip install --upgrade pip
        pip install flake8 pytest pre-commit
-        pip install --upgrade "jax[cpu]==0.4.11" "jaxlib==0.4.11"
-        # install haliax from source b/c it's changing in parallel with this repo
-        pip install git+https://github.com/stanford-crfm/haliax.git
-        pip install .
-#    - name: Lint with flake8
-#      run: |
-#        # stop the build if there are Python syntax errors or undefined names
-#        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-#        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-#        flake8 . --count --exit-zero --max-complexity=50 --max-line-length=127 --statistics
+        pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
     - name: "Run Pre-commit"
       run: |
        pre-commit run --all-files --show-diff-on-failure
diff --git a/config/gpt2_1536_sophiah.yaml b/config/gpt2_1536_sophiah.yaml
new file mode 100644
index 000000000..0d1008106
--- /dev/null
+++ b/config/gpt2_1536_sophiah.yaml
@@ -0,0 +1,32 @@
+data:
+  train_urls:
+    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
+  validation_urls:
+    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
+  cache_dir: "gs://levanter-data/tokenized/openwebtext/"
+  tokenizer: "gpt2"
+model:
+  type: gpt2
+  hidden_dim: 1536
+  num_heads: 24
+  num_layers: 48
+  seq_len: 1024
+  gradient_checkpointing: true
+  scale_attn_by_inverse_layer_idx: true
+trainer:
+  tracker:
+    project: "levanter"
+    tags: [ "openwebtext", "gpt2"]
+
+  mp: p=f32,c=bfloat16
+  model_axis_size: 1
+  per_device_parallelism: 2
+  per_device_eval_parallelism: 8
+optimizer:
+  type: sophia-h
+  learning_rate: 2E-4
+  weight_decay: 0.2
+  min_lr_ratio: 0.1
+  gamma: 0.01
+  # sophia needs a minimum amount of warmup or it doesn't do well
+  warmup: 2000
diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py
index eef14e026..c83beddaa 100644
--- a/examples/alpaca-lora/alpaca_lora.py
+++ b/examples/alpaca-lora/alpaca_lora.py
@@ -114,40 +114,40 @@ def loraize_hf_model(model):
             }
         )
 
-        logger.info(f"Total parameter count: {all_param_count}")
-        logger.info(f"Trainable parameter count: {just_lora_params}")
-        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
-
-        # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
-        # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
-        # datasets. We use replicated here since the dataset is small.
-        loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
-        loader = non_caching_cycle(loader)
-
-        if state.step != 0:
-            logger.info(f"Resuming training from step {state.step}")
-            for i in range(state.step):
-                next(loader)  # type: ignore
-
-        # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
-        if config.hf_save_path is not None:
-            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
-            trainer.add_hook(
-                save_peft_checkpoint_callback(
-                    full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
-                ),
-                every=config.hf_save_steps,
-            )
-
-        # Save merged HF checkpoints if requested
-        if config.merged_hf_save_path is not None:
-            full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
-            trainer.add_hook(
-                save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
-                every=config.hf_save_steps,
-            )
-
-        trainer.train(state, loader)
+        logger.info(f"Total parameter count: {all_param_count}")
+        logger.info(f"Trainable parameter count: {just_lora_params}")
+        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
+
+        # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
+        # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
+        # datasets. We use replicated here since the dataset is small.
+        loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
+        loader = non_caching_cycle(loader)
+
+        if state.step != 0:
+            logger.info(f"Resuming training from step {state.step}")
+            for i in range(state.step):
+                next(loader)  # type: ignore
+
+        # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
+        if config.hf_save_path is not None:
+            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            trainer.add_hook(
+                save_peft_checkpoint_callback(
+                    full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
+                ),
+                every=config.hf_save_steps,
+            )
+
+        # Save merged HF checkpoints if requested
+        if config.merged_hf_save_path is not None:
+            full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
+            trainer.add_hook(
+                save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
+                every=config.hf_save_steps,
+            )
+
+        trainer.train(state, loader)
 
 
 if __name__ == "__main__":
diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py
index ecabba8df..3042154e2 100644
--- a/src/levanter/__init__.py
+++ b/src/levanter/__init__.py
@@ -8,5 +8,5 @@
 import levanter.tracker as tracker
 import levanter.trainer as trainer
 import levanter.visualization as visualization
-from levanter.tracker import current_tracker, get_tracker
+from levanter.tracker import current_tracker
 from levanter.trainer import initialize
diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py
index 0ec178e08..d162882ac 100644
--- a/src/levanter/data/sharded_dataset.py
+++ b/src/levanter/data/sharded_dataset.py
@@ -147,9 +147,10 @@ class WrappedHFDataset(ShardedDataset[dict]):
         kwargs are passed to load_dataset
     """
 
-    def __init__(self, id, *, split, **kwargs):
+    def __init__(self, id, *, split, streaming: bool = True, **kwargs):
         self.id = id
         self.split = split
+        self.streaming = streaming
         self.kwargs = kwargs
         self._shard_names = self._compute_shard_names()
 
@@ -184,7 +185,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
     def _load_dataset(self):
         # obnoxiously, the dataset loading stuff doesn't work with ray because of multiprocessing
         # so we have to do this hacky thing where we load the dataset in the worker
-        return datasets.load_dataset(self.id, split=self.split, **self.kwargs)
+        return datasets.load_dataset(self.id, split=self.split, streaming=self.streaming, **self.kwargs)
 
 
 class TextUrlDataset(ShardedDataset[str]):
diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py
index 862aff5d5..ecad5e87f 100644
--- a/src/levanter/main/lora_lm.py
+++ b/src/levanter/main/lora_lm.py
@@ -92,6 +92,7 @@ def loraize_hf_model(model):
         state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter)
 
         all_param_count = parameter_count(state.model)
+        # TODO: remove this once we put this in trainer itself
         just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter))
 
         levanter.tracker.log_summary(
@@ -118,6 +119,7 @@ def loraize_hf_model(model):
         train_loader = trainer.sharded_loader(train_dataset, Batch)
 
         # boilerplate hooks and such
+        trainer.add_eval_hook(eval_dataset)
         trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
         if config.peft_save_path is not None:
             full_save_path = os.path.join(config.peft_save_path, trainer.run_id)
diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py
index 8895942a2..daa03285e 100644
--- a/src/levanter/optim/sophia.py
+++ b/src/levanter/optim/sophia.py
@@ -11,7 +11,7 @@
 from jax.random import PRNGKey
 from jaxtyping import PRNGKeyArray
 
-# import levanter.tracker
+import levanter.tracker
 from levanter.optim.config import HessianOptConfig, OptimizerConfig
 from levanter.optim.util import hvp, tree_gaussian_like
 from levanter.utils.jax_utils import parameter_count, tree_filter_like
@@ -348,9 +348,7 @@ def update_fn(updates, state, params=None, *, obj_fn, **kwargs):
         updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates)
         stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates)
 
-        # this doesn't work well on CPU, so skip if cpu
-        # if jax.lib.xla_bridge.get_backend().platform != "cpu":
-        #     levanter.tracker.jit_log_metrics(stats, step=state.count)
+        levanter.tracker.jit_log_metrics(stats, step=state.count)
 
         if mu_dtype is not None:
             mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu)
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 14aa98327..53611809f 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -259,29 +259,30 @@ def EvalBatch(self):
         return self.config.EvalBatch
 
     def __enter__(self):
+        if len(self._cmanagers) > 0:
+            raise RuntimeError("Trainer is already entered")
-        this_managers = [
+        self._cmanagers = [
             levanter.current_tracker(self.tracker),
             self.device_mesh,
             hax.axis_mapping(self.parameter_axis_mapping),
         ]
-        self._cmanagers.append(this_managers)
-        for cmanager in this_managers:
+        for cmanager in self._cmanagers:
             cmanager.__enter__()
         return self
 
     def __exit__(self, *args):
-        assert len(self._cmanagers) > 0, "Trainer.__exit__ called without corresponding Trainer.__enter__"
-        cur_managers = self._cmanagers.pop()
         problems = []
-        for cmanager in reversed(cur_managers):
+        for cmanager in reversed(self._cmanagers):
             try:
                 cmanager.__exit__(*args)
             except Exception as e:
                 problems.append(e)
 
+        self._cmanagers = []
+
         if len(problems) > 0:
             raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0]
@@ -395,23 +396,23 @@ def training_steps(
         Generator that yields training steps and runs hooks.
         """
         iter_data = iter(train_loader)
-        with levanter.current_tracker(self.tracker):
-            while state.step < self.num_train_steps:
-                with capture_time() as loading_time:
-                    example = next(iter_data)
-                levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step)
+        while state.step < self.num_train_steps:
+            with capture_time() as loading_time:
+                example = next(iter_data)
 
-                info = self.train_step(state, example)
+            levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step)
 
-                if run_hooks:
-                    with capture_time() as hook_time:
-                        self.run_hooks(info)
+            info = self.train_step(state, example)
+            state = info.state
 
-                    levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step)
+            if run_hooks:
+                with capture_time() as hook_time:
+                    self.run_hooks(info)
 
-                state = info.state
-                yield info
+                levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step)
+
+            yield info
 
     def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]:
         """
@@ -488,21 +489,21 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal
         key, new_key = jax.random.split(state.training_key)
         model = inference_mode(state.model, False)
 
-        loss, grads = self._compute_gradients_microbatched(model, batch, **batch_kwargs, key=key)
+        loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key)
 
         new_state = self._take_train_step(state, model, grads, *batch, **batch_kwargs, key=key)
         new_state = dataclasses.replace(new_state, training_key=new_key)
 
         return loss, new_state
 
-    def _compute_gradients_microbatched(self, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]:
-        grad_fn = eqx.filter_value_and_grad(self.loss_fn, has_aux=False)
+    def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]:
+        grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
         grad_fn = microbatched(
             grad_fn,
             self.TrainBatch,
-            self.config.per_device_parallelism,
-            self.parameter_axis_mapping,
+            self.config.microbatch_size, self.parameter_axis_mapping,
+            self.compute_axis_mapping,
         )
         return grad_fn(model, *batch, **batch_kwargs)
@@ -512,13 +513,14 @@ def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S:
         """
         # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us
         with hax.axis_mapping(self.parameter_axis_mapping):
+            opt_state = state.opt_state
             train_grads = _partition_trainable_params(grads, state.is_trainable)[0]
             trainable_model = _partition_trainable_params(model, state.is_trainable)[0]
-            updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model)
-
             partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs)
-            updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn)
+            updates, opt_state = self.optimizer.update(
+                train_grads, opt_state, params=trainable_model, obj_fn=partial_fn
+            )
             model = eqx.apply_updates(model, updates)
 
         return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state)
@@ -787,7 +789,7 @@ class AllConfig(Protocol):
 
 def initialize(config: TrainerConfig | AllConfig):
     """Initializes jax, logging, setting the run name/id in the process. Also initializes tracking and saves config
-    as hyperparameters and as an artifact"""
+    as hyperparameters and an artifact"""
     if isinstance(config, TrainerConfig):
         trainer_config = config
     else:
@@ -797,6 +799,10 @@ def initialize(config: TrainerConfig | AllConfig):
     levanter.tracker.log_configuration(config)
 
 
+def _params_only(t):
+    return eqx.filter(t, is_inexact_arrayish)
+
+
 def _partition_trainable_params(model, filter):
     """
     Partitions the model into trainable and non-trainable parameters. This is used internally
diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py
index 408a8c8da..ef77dcdfd 100644
--- a/src/levanter/utils/hf_utils.py
+++ b/src/levanter/utils/hf_utils.py
@@ -18,7 +18,8 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int:
         else:
             # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood.
             # we reserve a couple of cores just so Ray has somewhere to run the coordinator.
-            return min(max(1, logical_cpu_core_count() - 2), 32)
+            # Empirically it doesn't usually exceed 16-20, and it's useful to have some slack
+            return min(max(1, logical_cpu_core_count() - 2), 12)
     else:
         return 1
diff --git a/tests/test_doremi.py b/tests/test_doremi.py
index c6ac76a47..64a84a191 100644
--- a/tests/test_doremi.py
+++ b/tests/test_doremi.py
@@ -6,6 +6,7 @@
 
 import haliax as hax
 
+import levanter.tracker
 from levanter.callbacks import eval_loss_loop
 from levanter.data.dataset import ShardableDataset
 from levanter.data.mixture import MixtureDataset
@@ -136,15 +137,17 @@ def init_model():
     assert l3_ref < l1_ref < l2_ref
 
     from levanter.doremi import estimate_mixture_weights
-
-    w = estimate_mixture_weights(
-        initial_proxy=init_model(),
-        ref=ref_model,
-        data_sources=datasets,
-        trainer_config=tiny_trainer_config,
-        key=next(keys),
-        loss_fn=compute_loss_fn,
-    )
+    from levanter.tracker import NoopTracker
+
+    with levanter.tracker.current_tracker(NoopTracker()):
+        w = estimate_mixture_weights(
+            initial_proxy=init_model(),
+            ref=ref_model,
+            data_sources=datasets,
+            trainer_config=tiny_trainer_config,
+            key=next(keys),
+            loss_fn=compute_loss_fn,
+        )
 
     w1 = w["d1"]
     w2 = w["d2"]
diff --git a/tests/test_tensorstore_serialization.py b/tests/test_tensorstore_serialization.py
index 398bad1f0..caf8a365d 100644
--- a/tests/test_tensorstore_serialization.py
+++ b/tests/test_tensorstore_serialization.py
@@ -1,10 +1,12 @@
 from tempfile import TemporaryDirectory
+from typing import Any
 
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
+import pytest
 from chex import assert_trees_all_close
 
 import haliax as hax
@@ -127,3 +129,26 @@ def make_state(key):
         jax.tree_util.tree_leaves(arrays_only(restored_model)),
         jax.tree_util.tree_leaves(arrays_only(initial_model)),
     )
+
+
+def test_tensorstore_ok_with_nones():
+    A = hax.Axis("A", 10)
+
+    class MyModule(eqx.Module):
+        a: Any
+        b: Any
+
+    m = MyModule(a=None, b=hax.zeros(A))
+    m2 = MyModule(a=None, b=hax.ones(A))
+
+    with TemporaryDirectory() as tmpdir:
+        tree_serialize_leaves_tensorstore(tmpdir, m)
+        m3 = tree_deserialize_leaves_tensorstore(tmpdir, m2)
+        assert m3.a is None
+        assert hax.all(m3.b == hax.zeros(A))
+
+    m3 = MyModule(a=hax.zeros(A), b=hax.ones(A))
+    with TemporaryDirectory() as tmpdir:
+        tree_serialize_leaves_tensorstore(tmpdir, m2)
+        with pytest.raises(ValueError):
+            tree_deserialize_leaves_tensorstore(tmpdir, m3)