diff --git a/.github/workflows/run_pre_commit.yaml b/.github/workflows/run_pre_commit.yaml index eb7a4b214..e80facaeb 100644 --- a/.github/workflows/run_pre_commit.yaml +++ b/.github/workflows/run_pre_commit.yaml @@ -9,6 +9,7 @@ jobs: strategy: matrix: python-version: ["3.10"] + jax-version: ["0.4.14"] steps: - uses: actions/checkout@v3 @@ -20,16 +21,7 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest pre-commit - pip install --upgrade "jax[cpu]==0.4.11" "jaxlib==0.4.11" - # install haliax from source b/c it's changing in parallel with this repo - pip install git+https://github.com/stanford-crfm/haliax.git - pip install . -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 . --count --exit-zero --max-complexity=50 --max-line-length=127 --statistics + pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: "Run Pre-commit" run: | pre-commit run --all-files --show-diff-on-failure diff --git a/config/gpt2_small_pile.yaml b/config/gpt2_small_pile.yaml index ab7503871..19512c3dd 100644 --- a/config/gpt2_small_pile.yaml +++ b/config/gpt2_small_pile.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2"] diff --git a/config/gpt2_small_pile_mixture.yaml b/config/gpt2_small_pile_mixture.yaml index e02e4bd1f..a79ec8052 100644 --- a/config/gpt2_small_pile_mixture.yaml +++ b/config/gpt2_small_pile_mixture.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2"] diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index cfe07a1e4..de1fde555 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -230,9 +230,7 @@ def train(config: TrainArgs): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss) - - with trainer.device_mesh: + with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices parameter_axis_mapping = trainer.parameter_axis_mapping diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 548a113a0..3042154e2 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -8,4 +8,5 @@ import levanter.tracker as tracker import levanter.trainer as trainer import levanter.visualization as visualization +from levanter.tracker import current_tracker from levanter.trainer import initialize diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 569bbe711..f5faa9b36 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -16,7 +16,6 @@ import pyarrow as pa import pyarrow.parquet as pq import ray -import wandb from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from ray.actor import ActorHandle @@ -31,6 +30,8 @@ TimeRemainingColumn, ) +import levanter.tracker + from .. import logging from ..utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch @@ -739,7 +740,7 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - wandb.log(to_log, commit=self.commit) + levanter.tracker.log_metrics(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index ab6d9d6b9..9b056b950 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -51,7 +51,11 @@ def main(config: EvalLmConfig): if config.eval_on_train: raw_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) else: - raw_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) # type: ignore + validation_set = config.data.validation_set(Pos.size) + if validation_set is None: + raise ValueError("Can't eval on validation_set b/c there isn't one!") + + raw_dataset = CausalLmDataset(validation_set, Pos, KeyPos) # type: ignore eval_loader = ReplicatedBatchLoader(raw_dataset, config.trainer.device_mesh, Batch) compute_axis_mapping = config.trainer.compute_axis_mapping diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 0d0a7d70a..c492b2321 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -123,7 +123,7 @@ def compute_loss( across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not reduced, and the result is a named array with axes (*batch axes, sequence_length). """ - logits = self(example.tokens, example.attn_mask, key=key) + logits = self(example.tokens, example.attn_mask, key=key).astype(jnp.float32) targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) return cross_entropy_loss( diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py index 8895942a2..daa03285e 100644 --- a/src/levanter/optim/sophia.py +++ b/src/levanter/optim/sophia.py @@ -11,7 +11,7 @@ from jax.random import PRNGKey from jaxtyping import PRNGKeyArray -# import levanter.tracker +import levanter.tracker from levanter.optim.config import HessianOptConfig, OptimizerConfig from levanter.optim.util import hvp, tree_gaussian_like from levanter.utils.jax_utils import parameter_count, tree_filter_like @@ -348,9 +348,7 @@ def update_fn(updates, state, params=None, *, obj_fn, **kwargs): updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates) stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates) - # this doesn't work well on CPU, so skip if cpu - # if jax.lib.xla_bridge.get_backend().platform != "cpu": - # levanter.tracker.jit_log_metrics(stats, step=state.count) + levanter.tracker.jit_log_metrics(stats, step=state.count) if mu_dtype is not None: mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu) diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 408a8c8da..ef77dcdfd 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -18,7 +18,8 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int: else: # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood. # we reserve a couple of cores just so Ray has somewhere to run the coordinator. - return min(max(1, logical_cpu_core_count() - 2), 32) + # Empirically it doesn't usually exceed 16-20, and it's useful to have some slack + return min(max(1, logical_cpu_core_count() - 2), 12) else: return 1