
Commit

Misc fixes (#461)
* ensure we do the final softmax in fp32

* tweak number of cores for tokenization

* use the trainer better

* switch to tracker in shard_cache

* expose current_tracker in levanter base package

* add back in logging of sophia metrics

* config fixes

* missed a spot

* raise if there's no validation set in eval_lm

* is precommit being crappy?

* sigh
dlwh authored Feb 11, 2024
1 parent 5de54dc commit 50f72d7
Showing 10 changed files with 19 additions and 24 deletions.
12 changes: 2 additions & 10 deletions .github/workflows/run_pre_commit.yaml
@@ -9,6 +9,7 @@ jobs:
    strategy:
      matrix:
        python-version: ["3.10"]
+        jax-version: ["0.4.14"]

    steps:
      - uses: actions/checkout@v3
@@ -20,16 +21,7 @@ jobs:
      run: |
        python -m pip install --upgrade pip
        pip install flake8 pytest pre-commit
-        pip install --upgrade "jax[cpu]==0.4.11" "jaxlib==0.4.11"
-        # install haliax from source b/c it's changing in parallel with this repo
-        pip install git+https://github.com/stanford-crfm/haliax.git
-        pip install .
-    # - name: Lint with flake8
-    #   run: |
-    #     # stop the build if there are Python syntax errors or undefined names
-    #     flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-    #     # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-    #     flake8 . --count --exit-zero --max-complexity=50 --max-line-length=127 --statistics
+        pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
    - name: "Run Pre-commit"
      run: |
        pre-commit run --all-files --show-diff-on-failure
2 changes: 1 addition & 1 deletion config/gpt2_small_pile.yaml
@@ -8,7 +8,7 @@ model:
  gradient_checkpointing: true
  scale_attn_by_inverse_layer_idx: true
trainer:
-  wandb:
+  tracker:
    project: "levanter"
    tags: [ "pile", "gpt2"]
2 changes: 1 addition & 1 deletion config/gpt2_small_pile_mixture.yaml
@@ -8,7 +8,7 @@ model:
  gradient_checkpointing: true
  scale_attn_by_inverse_layer_idx: true
trainer:
-  wandb:
+  tracker:
    project: "levanter"
    tags: [ "pile", "gpt2"]
4 changes: 1 addition & 3 deletions examples/alpaca/alpaca.py
@@ -230,9 +230,7 @@ def train(config: TrainArgs):
    def compute_loss(model: LmHeadModel, example: LmExample, key=None):
        return model.compute_loss(example, key=key).scalar()

-    trainer = Trainer(config.trainer, optimizer, compute_loss)
-
-    with trainer.device_mesh:
+    with Trainer(config.trainer, optimizer, compute_loss) as trainer:
        # how we shard parameters across devices
        parameter_axis_mapping = trainer.parameter_axis_mapping

1 change: 1 addition & 0 deletions src/levanter/__init__.py
@@ -8,4 +8,5 @@
import levanter.tracker as tracker
import levanter.trainer as trainer
import levanter.visualization as visualization
+from levanter.tracker import current_tracker
from levanter.trainer import initialize
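
With current_tracker re-exported from the package root, code outside the trainer can log to whatever tracker is active without importing wandb directly. A minimal sketch of that usage, assuming only names that appear elsewhere in this commit (current_tracker, levanter.tracker.log_metrics, levanter.initialize) and an already-initialized trainer/tracker; the Tracker interface beyond these calls is an assumption:

import levanter

# Assumes a tracker has already been configured (e.g. via levanter.initialize or the Trainer).
active = levanter.current_tracker()  # the tracker the trainer set up (assumed to be callable)
levanter.tracker.log_metrics({"cache/rows_processed": 1000}, step=0)
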
5 changes: 3 additions & 2 deletions src/levanter/data/shard_cache.py
@@ -16,7 +16,6 @@
import pyarrow as pa
import pyarrow.parquet as pq
import ray
-import wandb
from dataclasses_json import dataclass_json
from fsspec import AbstractFileSystem
from ray.actor import ActorHandle
@@ -31,6 +30,8 @@
    TimeRemainingColumn,
)

+import levanter.tracker
+
from .. import logging
from ..utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info
from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch
@@ -739,7 +740,7 @@ def __call__(self, metrics: InProgressCacheMetrics):
        self.last_metrics = metrics
        self.last_time = time.time()

-        wandb.log(to_log, commit=self.commit)
+        levanter.tracker.log_metrics(to_log, step=None, commit=self.commit)


class LoggerMetricsMonitor(MetricsMonitor):
6 changes: 5 additions & 1 deletion src/levanter/main/eval_lm.py
@@ -51,7 +51,11 @@ def main(config: EvalLmConfig):
    if config.eval_on_train:
        raw_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos)
    else:
-        raw_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos)  # type: ignore
+        validation_set = config.data.validation_set(Pos.size)
+        if validation_set is None:
+            raise ValueError("Can't eval on validation_set b/c there isn't one!")
+
+        raw_dataset = CausalLmDataset(validation_set, Pos, KeyPos)  # type: ignore

    eval_loader = ReplicatedBatchLoader(raw_dataset, config.trainer.device_mesh, Batch)
    compute_axis_mapping = config.trainer.compute_axis_mapping
2 changes: 1 addition & 1 deletion src/levanter/models/lm_model.py
@@ -123,7 +123,7 @@ def compute_loss(
        across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
        reduced, and the result is a named array with axes (*batch axes, sequence_length).
        """
-        logits = self(example.tokens, example.attn_mask, key=key)
+        logits = self(example.tokens, example.attn_mask, key=key).astype(jnp.float32)
        targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
        target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)
        return cross_entropy_loss(
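
The .astype(jnp.float32) cast is the "final softmax in fp32" fix from the commit message: when logits come out of the model in bf16/fp16, the log-softmax inside the cross-entropy can lose enough precision to bias the loss (the one_hot targets pick up the same dtype via dtype=logits.dtype). A standalone illustration of the effect, not part of the diff, in plain JAX:

import jax
import jax.numpy as jnp

logits = jnp.array([12.3, -4.5, 7.1, 0.2])

lp_bf16 = jax.nn.log_softmax(logits.astype(jnp.bfloat16))
lp_fp32 = jax.nn.log_softmax(logits.astype(jnp.float32))

# The bf16 log-probabilities can drift visibly for the low-probability tokens,
# which biases the cross-entropy; computing the softmax in fp32 avoids this.
print(lp_bf16.astype(jnp.float32) - lp_fp32)
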
6 changes: 2 additions & 4 deletions src/levanter/optim/sophia.py
@@ -11,7 +11,7 @@
from jax.random import PRNGKey
from jaxtyping import PRNGKeyArray

-# import levanter.tracker
+import levanter.tracker
from levanter.optim.config import HessianOptConfig, OptimizerConfig
from levanter.optim.util import hvp, tree_gaussian_like
from levanter.utils.jax_utils import parameter_count, tree_filter_like
@@ -348,9 +348,7 @@ def update_fn(updates, state, params=None, *, obj_fn, **kwargs):
        updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates)
        stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates)

-        # this doesn't work well on CPU, so skip if cpu
-        # if jax.lib.xla_bridge.get_backend().platform != "cpu":
-        #     levanter.tracker.jit_log_metrics(stats, step=state.count)
+        levanter.tracker.jit_log_metrics(stats, step=state.count)

        if mu_dtype is not None:
            mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu)
3 changes: 2 additions & 1 deletion src/levanter/utils/hf_utils.py
@@ -18,7 +18,8 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int:
        else:
            # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood.
            # we reserve a couple of cores just so Ray has somewhere to run the coordinator.
-            return min(max(1, logical_cpu_core_count() - 2), 32)
+            # Empirically it doesn't usually exceed 16-20, and it's useful to have some slack
+            return min(max(1, logical_cpu_core_count() - 2), 12)
    else:
        return 1

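
A quick self-contained check of the new cap. The helper below takes the logical core count directly (the real function takes a tokenizer and calls logical_cpu_core_count()), so it only illustrates the arithmetic:

def cores_for_tokenization(logical_cores: int) -> int:
    # Mirrors the updated formula: leave 2 cores for the Ray coordinator, cap at 12.
    return min(max(1, logical_cores - 2), 12)

assert cores_for_tokenization(64) == 12  # large hosts: now capped at 12 instead of 32
assert cores_for_tokenization(4) == 2    # small hosts: still leave headroom for Ray
assert cores_for_tokenization(1) == 1    # never below 1
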
