diff --git a/docs/dev/Trackers.md b/docs/dev/Trackers.md index 1f1677d52..7726edabe 100644 --- a/docs/dev/Trackers.md +++ b/docs/dev/Trackers.md @@ -11,23 +11,23 @@ Given Levanter's historical dependency on W&B, the interface is designed to look The methods currently exposed are: * [levanter.tracker.current_tracker][]: returns the current tracker instance or sets it. -* [levanter.tracker.log_metrics][]: logs a dictionary of metrics for a given step. +* [levanter.tracker.log][]: logs a dictionary of metrics for a given step. * [levanter.tracker.log_summary][]: logs a dictionary of "summary" information, analogous to W&B's version. * [levanter.tracker.get_tracker][]: returns a tracker with the given name. -* [levanter.tracker.jit_log_metrics][]: a version of [levanter.tracker.log_metrics][] that works inside JAX jit. +* [levanter.tracker.jit_log][]: a version of [levanter.tracker.log][] that accumulates metrics inside of a `jit`-ted function. A basic example of using the tracker interface is shown below: ```python import wandb -from levanter.tracker import current_tracker, log_metrics, log_summary +import levanter.tracker as tracker from levanter.tracker.wandb import WandbTracker -with current_tracker(WandbTracker(wandb.init())): +with tracker.current_tracker(WandbTracker(wandb.init())): for step in range(100): - log_metrics({"loss": 100 -0.01 * step}, step=step) + tracker.log({"loss": 100 - 0.01 * step}, step=step) - log_summary({"best_loss": 0.0}) + tracker.log_summary({"best_loss": 0.0}) ``` A more typical example would be to use it in a config file, as we do with Trainer: @@ -73,13 +73,13 @@ TODO: expand this section. ::: levanter.tracker.current_tracker -::: levanter.tracker.log_metrics +::: levanter.tracker.log ::: levanter.tracker.log_summary ::: levanter.tracker.get_tracker -::: levanter.tracker.jit_log_metrics +::: levanter.tracker.jit_log ### Trackers diff --git a/infra/run.sh b/infra/run.sh index 9a0456fb7..7cbc0f969 100755 --- a/infra/run.sh +++ b/infra/run.sh @@ -1,5 +1,6 @@ umask 000 LEV_ROOT=$(dirname "$(readlink -f $0)")/.. +ulimit -s 65536 # figure out venv, first check if we wrote a path in infra/venv_path if [ ! 
-d "$VENV" ] && [ -f "$LEV_ROOT/infra/venv_path.txt" ]; then diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 983750685..135d10dd5 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -1,28 +1,33 @@ +import abc import copy import logging as pylogging import os -import re -import subprocess import sys -import tempfile import threading import time -import warnings +from abc import ABC from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from datetime import timedelta -from typing import Callable, Optional +from typing import Any, Callable, Generic, Optional, TypeVar -import humanfriendly import jax +import jax.numpy as jnp +from jaxtyping import PyTree from tqdm_loggable import tqdm_logging from tqdm_loggable.auto import tqdm +import haliax.nn +from haliax import NamedArray, is_named_array +from haliax.jax_utils import is_jax_array_like + import levanter.tracker from levanter.data import AsyncDataset, DataLoader from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.histogram import Histogram from levanter.tracker.wandb import WandbConfig -from levanter.trainer import StepInfo -from levanter.utils import flop_utils +from levanter.trainer_state import TrainerState +from levanter.utils import flop_utils, jax_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.utils.logging import save_xla_dumps_to_wandb from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -30,6 +35,68 @@ logger = pylogging.getLogger(__name__) +M = TypeVar("M") # Model +M_con = TypeVar("M_con", bound=PyTree, contravariant=True) +S = TypeVar("S", bound=TrainerState) +CBInfo = TypeVar("CBInfo") + + +@dataclass +class StepInfo(Generic[S]): + """ + Information about a step that was just completed. This includes the trainer state, the loss, and the duration of the + step. + + Note that the step is 0-indexed, so if you want the next step, use `next_step`. + """ + + state: S + loss: float + step_duration: float + + model = property(lambda self: self.state.model) + opt_state = property(lambda self: self.state.opt_state) + + step = property(lambda self: int(self.state.step) - 1) + """ + The step that was just completed. If you want the next step, use `next_step`. + """ + + next_step = property(lambda self: int(self.state.step)) + + +class Callback(ABC, Generic[S]): + """ + A callback that can be called at the end of a step. This is useful for logging, profiling, and other side effects. + """ + + @abc.abstractmethod + def on_step(self, info: StepInfo[S], force: bool = False): + ... + + +class LambdaCallback(Callback[S]): + def __init__(self, fn: Callable[[StepInfo[S]], Any]): + self.fn = fn + + def on_step(self, info: StepInfo[S], force: bool = False): + self.fn(info) + + +class JitCallback(ABC, Generic[S, M, CBInfo]): + """ + A callback that gets called in two phases: inside the step (inside jit), and after the step (outside jit). + You have access to the gradients inside the step, so you can compute statistics on them. + """ + + @abc.abstractmethod + def inside_step(self, state: S, grad: M) -> CBInfo: + ... + + @abc.abstractmethod + def on_step(self, step_info: S, cb_info: CBInfo): + ... 
+ def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size, max_epochs: Optional[int] = None): total_tokens = None @@ -50,7 +117,7 @@ def log_epoch(step_info: StepInfo): total_tokens_for_epochs = total_tokens * max_epochs if max_epochs else total_tokens current_epoch = processed_tokens / total_tokens_for_epochs - levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step) + levanter.tracker.log({"train/current_epoch": current_epoch}, step=step_info.step) return log_epoch @@ -113,9 +180,6 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n if n > 0: total_loss /= n - # logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") - # logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") - return total_loss @@ -131,7 +195,7 @@ def compute_loss(info: StepInfo): prefix = "eval" if name: prefix += "/" + name - levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step) + levanter.tracker.log({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -149,7 +213,7 @@ def log_step_info_inner(step: StepInfo): if total_steps: metrics["run_progress"] = step.step / total_steps log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") - levanter.tracker.log_metrics(metrics, step=step.step) + levanter.tracker.log(metrics, step=step.step) return log_step_info_inner @@ -226,7 +290,7 @@ def log_performance_stats(step_info: StepInfo): dict_to_log["mfu"] = mfu_instant dict_to_log = {wrap_key(k): v for k, v in dict_to_log.items()} - levanter.tracker.log_metrics(dict_to_log, step=step_info.step) + levanter.tracker.log(dict_to_log, step=step_info.step) return log_performance_stats @@ -248,100 +312,7 @@ def update_pbar(step: StepInfo): return update_pbar -def log_memory_usage(sample_interval: float = 1.0, log_individual_devices: bool = False): - """ - Logs memory usage. This runs a loop that samples memory usage every `sample_interval` seconds. - We only log when hooks are invoked, so there's not much point in running this much more frequently than you invoke - the hook. - - I think it's a good idea to run this in a separate thread, so that you sample from random points, but I'm not sure. - :param sample_interval: - :return: - """ - - directory = "/dev/shm" - # macos doesn't have /dev/shm - if not os.path.exists(directory): - directory = tempfile.gettempdir() - - tempfile_name = os.path.join(directory, f"memory_usage_{os.getpid()}.prof") - - # a lot of this code is lifted from https://github.com/ayaka14732/jax-smi CC-0 - - def inner(): - import posix - import time - - while True: - jax.profiler.save_device_memory_profile(f"{tempfile_name}.new") - posix.rename(f"{tempfile_name}.new", tempfile_name) - time.sleep(sample_interval) - - thread = threading.Thread(target=inner, daemon=True) - thread.start() - - def log_memory_usage(step: StepInfo): - process = subprocess.run( - args=f"go tool pprof -tags {tempfile_name}".split(" "), - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - ) - - if process.returncode != 0: - warnings.warn("failed to run pprof. 
Is go installed?") - return - - output = process.stdout.decode("utf-8") - - # output looks like this: - # 2.4MB (12.53%): TFRT_CPU_0 - # 2.4MB (12.50%): TFRT_CPU_1 - # 2.4MB (12.50%): TFRT_CPU_2 - # 2.4MB (12.50%): TFRT_CPU_3 - # 2.4MB (12.50%): TFRT_CPU_4 - # 2.4MB (12.50%): TFRT_CPU_5 - # 2.4MB (12.50%): TFRT_CPU_6 - # 2.4MB (12.50%): TFRT_CPU_7 - # - # kind: Total 19.5MB - # 18.9MB (97.20%): buffer - # 558.4kB ( 2.80%): executable - - # gpus look like this: - # 1.0MB ( 0.00%): gpu:0 - per_device, by_kind = output.split("kind: Total ") - - # first, get the total memory usage - regex = re.compile(r"^(\d+\.\d+[a-zA-Z]+)") - match = regex.search(by_kind) - if match: - memory_usage = humanfriendly.parse_size(match.group(1)) - levanter.tracker.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) - - # this works for the "kind" and the individual devices - regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") - - if log_individual_devices: - # now, get the memory usage per device. - # split the output at kind: Total - for match in regex.finditer(per_device): - memory_usage = humanfriendly.parse_size(match.group(1)) - device_name = match.group(3) - levanter.tracker.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) - - # now, get the memory usage per kind. - # same regex as above - for match in regex.finditer(by_kind): - memory_usage = match.group(1) - memory_usage = humanfriendly.parse_size(memory_usage) - levanter.tracker.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) - - return log_memory_usage - - def profile(path: str, start_step: int, num_steps: int, create_perfetto_link: bool) -> Callable[[StepInfo], None]: - print(f"create_perfetto_link: {create_perfetto_link}") - def profiler_callback_fn(step: StepInfo): # -1 b/c step is the finished step if step.step == start_step - 1: @@ -423,3 +394,133 @@ def _tqdm_logging_one_time_setup(): return _did_tqdm_logging_one_time_setup = True tqdm_logging.tqdm_logging.set_log_rate(timedelta(seconds=60)) + + +class GradWatchCallback(JitCallback[S, M, dict[str, float | Histogram]]): + """ + Emulates the behavior of Wandb's PyTorch-only built-in gradient logging (wandb.watch) + + Args: + prefix (str): The prefix to use for logging. + include_histogram (bool): Whether to include histograms of the gradients. + split_scan_layers (bool): Whether to split the scan layers into separate histograms/norms + """ + + def __init__( + self, + prefix: str = "grad", + include_histogram: bool = True, + split_scan_layers: bool = True, + ): + self.prefix = prefix + self.include_histogram = include_histogram + self.split_scan_layers = split_scan_layers + + def inside_step(self, state: TrainerState[M], grad: M): + return summary_statistics_for_tree(self.prefix, grad, self.split_scan_layers, self.include_histogram) + + def on_step(self, step_info: StepInfo[S], cb_info: dict[str, float | Histogram]): + levanter.tracker.log(cb_info, step=step_info.step) + + +class ParamWatchCallback(JitCallback[S, M, dict[str, float | Histogram]]): + """ + Emulates the behavior of Wandb's PyTorch-only built-in gradient logging (wandb.watch) + + Args: + prefix (str): The prefix to use for logging. + include_histogram (bool): Whether to include histograms of the gradients. 
+ split_scan_layers (bool): Whether to split the scan layers into separate histograms/norms + """ + + def __init__( + self, + prefix: str = "params", + include_histogram: bool = True, + split_scan_layers: bool = True, + ): + self.prefix = prefix + self.include_histogram = include_histogram + self.split_scan_layers = split_scan_layers + + def inside_step(self, state: TrainerState[M], grad: M): + return summary_statistics_for_tree( + self.prefix, state.trainable_model, self.split_scan_layers, self.include_histogram + ) + + def on_step(self, step_info: StepInfo[S], cb_info: dict[str, float | Histogram]): + levanter.tracker.log(cb_info, step=step_info.step) + + +def summary_statistics_for_tree( + prefix: str, tree: M, split_scan_layers: bool, include_histogram: bool +) -> dict[str, float | Histogram]: + """ + Computes the summary statistics for a tree of (named) arrays. + + This function is designed to allow you to emulate the behavior of Wandb's PyTorch-only built-in gradient logging, + but also works for any PyTree. It computes the Frobenius norm of each array, + and optionally the histogram as well. + + Args: + prefix: The prefix to use for logging. + tree: The tree of arrays to compute the summary statistics for. + split_scan_layers: Whether to split the scan layers into separate histograms/norms. Recommended. + include_histogram: Whether to include histograms of the arrays. This increases overhead significantly. + + Returns: + A dict mapping `{prefix}/norm/{key}` to the norm of each array and, when `include_histogram` is set, + `{prefix}/hist/{key}` to its Histogram. + """ + if split_scan_layers: + is_leaf = lambda n: isinstance(n, haliax.nn.Stacked) or is_named_array(n) # noqa: E731 + else: + is_leaf = is_named_array + + def _rec_log_magnitudes(norms, hists, path_prefix, tree): + leaf_key_paths = jax_utils.leaf_key_paths(tree, prefix=path_prefix, is_leaf=is_leaf) + del path_prefix + for key_path, g in zip( + jax.tree.leaves(leaf_key_paths, is_leaf=is_leaf), + jax.tree.leaves(tree, is_leaf=is_leaf), + strict=True, + ): + if split_scan_layers and isinstance(g, haliax.nn.Stacked): + vmapped_norms, vmapped_hists = haliax.vmap(_rec_log_magnitudes, g.Block)({}, {}, "", g.stacked) + + for k, v in vmapped_norms.items(): + for i in range(g.Block.size): + norms[f"{key_path}/{i}/{k}"] = v[i] + + for k, v in vmapped_hists.items(): + for i in range(g.Block.size): + hists[f"{key_path}/{i}/{k}"] = jax.tree.map(lambda x: x[i] if is_jax_array_like(x) else x, v) + + elif isinstance(g, NamedArray): + # TODO: add linalg.norm to Haliax + norms[key_path] = jnp.linalg.norm(g.array) + if include_histogram: + hist = Histogram.from_named_array(g) + hists[key_path] = hist + elif is_jax_array_like(g): + norms[key_path] = jnp.linalg.norm(g) + + if include_histogram: + hist = Histogram.from_array(g) + hists[key_path] = hist + + return norms, hists + + norms_to_log: dict[str, jax.Array] = {} + hists_to_log: dict[str, Histogram] = {} + + _rec_log_magnitudes(norms_to_log, hists_to_log, None, tree) + + to_log: dict = {} + + for key, value in norms_to_log.items(): + to_log[f"{prefix}/norm/{key}"] = value + + for key, value in hists_to_log.items(): + to_log[f"{prefix}/hist/{key}"] = value + + return to_log diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index dc2f0e16d..c61e3ac15 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -34,9 +34,9 @@ from haliax.partitioning import ResourceMapping from haliax.state_dict import from_torch_compatible_state_dict, save_state_dict, to_torch_compatible_state_dict +from levanter.callbacks import StepInfo from 
levanter.models.asr_model import ASRMixin from levanter.models.lm_model import LmConfig, LmHeadModel -from levanter.trainer import StepInfo from levanter.utils import jax_utils from levanter.utils.cloud_utils import temp_dir_before_upload from levanter.utils.hf_utils import HfTokenizer diff --git a/src/levanter/data/metrics_monitor.py b/src/levanter/data/metrics_monitor.py index 4e4619ffb..96c17ec65 100644 --- a/src/levanter/data/metrics_monitor.py +++ b/src/levanter/data/metrics_monitor.py @@ -122,7 +122,7 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - levanter.tracker.log_metrics(to_log, step=None, commit=self.commit) + levanter.tracker.log(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index c066e55d5..cdcfe68cd 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -12,12 +12,12 @@ from haliax.types import IntScalar import levanter.tracker -from levanter.callbacks import eval_loss_loop +from levanter.callbacks import M, StepInfo, eval_loss_loop from levanter.checkpoint import load_checkpoint_or_initialize from levanter.data import AsyncDataset, MappedAsyncDataset from levanter.data.mixture import MixtureDataset from levanter.tracker import capture_time -from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState +from levanter.trainer import Trainer, TrainerConfig, TrainerState from levanter.utils.tree_utils import inference_mode from levanter.utils.types import ComputeLossFunction @@ -112,7 +112,7 @@ def eval_loss(model, *batch, **batch_kwargs): max_batches=trainer_config.max_eval_batches, ) print(f"Loss of ref model on domain {domain}: {loss:.3f}") - levanter.tracker.log_metrics({f"eval/ref/{domain}/loss": loss}, step=0, commit=False) + levanter.tracker.log({f"eval/ref/{domain}/loss": loss}, step=0, commit=False) if validation_sets is not None: for domain, dataset in validation_sets.items(): @@ -167,7 +167,7 @@ def doremi_step(state: DoremiState, ref, batch, domains): # need to use where b/c we're in jit per_domain_dict = {k: jnp.where(v == 0.0, jnp.nan, v) for k, v in per_domain_dict.items()} - levanter.tracker.jit_log_metrics( + levanter.tracker.jit_log( { "change_in_alpha": alpha_distance.scalar(), "alpha_distance_from_uniform": distance_from_uniform.scalar(), diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 6f40888cd..ada22bc14 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -16,9 +16,9 @@ from haliax.partitioning import ResourceMapping import levanter.tracker +from levanter.callbacks import StepInfo from levanter.data import AsyncDataset, DataLoader from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss -from levanter.trainer import StepInfo from levanter.utils.hf_utils import HfTokenizer, byte_length_of_token from levanter.utils.logging import LoadingTimeTrackerIterator from levanter.utils.stat_utils import Arrayish, RunningMean @@ -238,7 +238,7 @@ def eval_callback(step: StepInfo): for tag, bpb in result.tag_macro_bpb.items(): log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb - levanter.tracker.log_metrics(log_dict, step=step.step) + levanter.tracker.log(log_dict, step=step.step) return result diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 8f102e1ad..4d167c562 
100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -11,7 +11,9 @@ import haliax as hax from haliax import Axis from haliax.partitioning import ResourceAxis -from haliax.util import is_jax_array_like, is_named_array +from haliax.util import is_named_array + +from levanter.utils.jax_utils import zeros_like_tree Args = ParamSpec("Args") @@ -90,7 +92,7 @@ def wrapped_fn(*args, **kwargs): # first, determine the shape and make accumulator arrays r_shape = eqx.filter_eval_shape(fn, *args, **kwargs) - acc = _zeros_like_tree(r_shape, accum_axis_mapping, accum_dtype) + acc = zeros_like_tree(r_shape, accum_axis_mapping, accum_dtype) # then, reshape the inputs from (Batch, ...) to (AccumStep, Microbatch, ...) @@ -147,19 +149,3 @@ def _reshape(x): return x return jax.tree_util.tree_map(_reshape, inputs, is_leaf=is_named_array) - - -def _zeros_like_tree(r_shape, axis_mapping, accum_dtype): - _zeros = functools.partial(_zeros_like, axis_mapping, accum_dtype) - acc = jax.tree_util.tree_map(_zeros, r_shape, is_leaf=is_named_array) - return acc - - -def _zeros_like(mapping, dtype, n): - if isinstance(n, hax.NamedArray): - return hax.shard(hax.zeros_like(n, dtype=dtype), mapping) - elif is_jax_array_like(n): - return jnp.zeros_like(n, dtype) - else: - assert jnp.isscalar(n) - return 0.0 diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index cea436942..75be8d206 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -246,6 +246,8 @@ def main(config: TrainLmConfig): trainer.add_hook( callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1 ) + # trainer.add_hook(callbacks.GradWatchCallback(include_histogram=True), every=5) + if config.hf_save_path is not None: # bit gross to reach this far into the config, but it's fine if config.trainer.checkpointer.append_run_id_to_base_path: diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py index 93b439681..0c8fb37b8 100644 --- a/src/levanter/optim/sophia.py +++ b/src/levanter/optim/sophia.py @@ -351,7 +351,7 @@ def update_fn(updates, state, params=None, *, obj_fn, **kwargs): updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates) stats["optim/unclipped_fraction"] = unclipped_count * 1.0 / float(parameter_count(updates)) - levanter.tracker.jit_log_metrics(stats, step=state.count) + levanter.tracker.jit_log(stats, step=state.count) if mu_dtype is not None: mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu) diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py index 587c05d50..9bdaede65 100644 --- a/src/levanter/tracker/__init__.py +++ b/src/levanter/tracker/__init__.py @@ -3,7 +3,8 @@ from levanter.tracker.tracker_fns import ( current_tracker, get_tracker, - jit_log_metrics, + jit_log, + log, log_configuration, log_hyperparameters, log_metrics, @@ -20,11 +21,12 @@ "NoopTracker", "current_tracker", "get_tracker", - "jit_log_metrics", + "jit_log", "log_configuration", - "log_metrics", + "log", "log_summary", "log_hyperparameters", "set_global_tracker", "capture_time", + "log_metrics", ] diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py index 92b943a2c..bd29f873e 100644 --- a/src/levanter/tracker/helpers.py +++ b/src/levanter/tracker/helpers.py @@ -30,7 +30,7 @@ def wrap_key(key): if hasattr(opt_state, "hyperparams"): params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} - 
levanter.tracker.log_metrics(params, step=step) + levanter.tracker.log(params, step=step) def hparams_to_dict(hparams, **extra_hparams): diff --git a/src/levanter/tracker/histogram.py b/src/levanter/tracker/histogram.py new file mode 100644 index 000000000..9ab983a12 --- /dev/null +++ b/src/levanter/tracker/histogram.py @@ -0,0 +1,118 @@ +import functools + +import equinox +import jax +import jax.numpy as jnp +import numpy as np +from jax._src.partition_spec import PartitionSpec +from jax.experimental.shard_map import shard_map +from jaxtyping import ArrayLike, Scalar + +import haliax as hax +from haliax import NamedArray + + +class Histogram(equinox.Module): + """ + Has enough information to log to tensorboard and wandb + """ + + min: Scalar + max: Scalar + num: Scalar | int + sum: Scalar + sum_squares: Scalar + bucket_limits: jax.Array + bucket_counts: jax.Array + + @staticmethod + def from_array(array: jax.Array, num_bins: int = 64) -> "Histogram": + array = array.ravel() + min = array.min() + max = array.max() + num = array.size + sum = array.sum() + sum_squares = (array**2).sum() + counts, edges = jax.numpy.histogram(array, bins=num_bins) + return Histogram(min, max, num, sum, sum_squares, edges, counts) + + @staticmethod + def from_named_array(array: hax.NamedArray, num_bins: int = 64) -> "Histogram": + raw_array = array.array + min = raw_array.min() + max = raw_array.max() + num = array.size + sum = raw_array.sum() + sum_squares = (raw_array**2).sum() + counts, edges = sharded_histogram(array, bins=num_bins) + return Histogram(min, max, num, sum, sum_squares, edges, counts) + + def to_numpy_histogram(self) -> tuple[np.ndarray, np.ndarray]: + return np.array(self.bucket_counts), np.array(self.bucket_limits) + + +def sharded_histogram(a: NamedArray, bins: int | ArrayLike = 10) -> tuple[jnp.ndarray, jnp.ndarray]: + """ + As [jax.numpy.histogram](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.histogram.html#jax.numpy.histogram), + except: + + * It preserves sharding + * It only works with NamedArrays + * It is more performant on TPUs + + Credit to @aphoh for the original implementation, though that one crashes on TPUs due to some kind of driver bug + """ + edges = jnp.histogram_bin_edges(a.array, bins=bins) + return _shardmap_histogram(a, edges), edges + + +def _single_shard_histogram(a, bins, reduce_mesh): + """Modified version of jax.numpy.histogram that returns integer counts instead of using the datatype of the input. + Also avoids searchsorted, which is slow on TPUs. + Args: + a (Array): input array + bins (Array): bins to use for histogram + Returns: + Array: counts. 
has length len(bins) - 1 + """ + a = a.flatten() + + bin_idx = (a[..., None] >= bins[:-1]).astype(jnp.int32) & (a[..., None] < bins[1:]).astype(jnp.int32) + counts = bin_idx.sum(axis=0, dtype=jnp.int32) + + if len(reduce_mesh): + counts = jax.lax.psum(counts, axis_name=reduce_mesh) + return counts + + +def _shardmap_histogram(a: NamedArray, bins): + mesh = hax.partitioning._get_mesh() + spec = hax.partitioning.pspec_for_axis(a.axes) + flattened_spec = _flattened_spec(spec) + shard_h = shard_map( + functools.partial(_single_shard_histogram, reduce_mesh=flattened_spec), + mesh=mesh, + in_specs=(spec, PartitionSpec(None)), + out_specs=PartitionSpec( + None, + ), + ) + res = shard_h(a.array, bins) + + # the filter misses the last bin, so we need to add it + if res.size >= 1: + res = res.at[-1].add(1) + return res + + +def _flattened_spec(spec): + out = [] + for s in spec: + if isinstance(s, tuple): + out.extend(s) + elif s is None: + pass + else: + out.append(s) + + return tuple(out) diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py index e819d6459..e898bf045 100644 --- a/src/levanter/tracker/tensorboard.py +++ b/src/levanter/tracker/tensorboard.py @@ -5,8 +5,11 @@ from typing import Any, Optional import fsspec +import jax +import numpy as np from levanter.tracker import Tracker, TrackerConfig +from levanter.tracker.histogram import Histogram pylogger = logging.getLogger(__name__) @@ -21,13 +24,36 @@ class TensorboardTracker(Tracker): def __init__(self, writer: "SummaryWriter"): self.writer = writer - def log_hyperparameters(self, hparams: dict[str, Any]): + def log_hyperparameters(self, hparams: typing.Mapping[str, Any]): self.writer.add_hparams(hparams, {"dummy": 0}) - def log(self, metrics: dict[str, Any], *, step, commit=None): + def log(self, metrics: typing.Mapping[str, Any], *, step, commit=None): del commit - for k, v in metrics.items(): - self.writer.add_scalar(k, v, step) + metrics = _flatten_nested_dict(metrics) + for k, value in metrics.items(): + if isinstance(value, jax.Array): + if value.ndim == 0: + value = value.item() + else: + value = np.array(value) + elif isinstance(value, Histogram): + num = value.num + if hasattr(num, "item"): + num = num.item() + self.writer.add_histogram_raw( + k, + min=value.min.item(), + max=value.max.item(), + num=num, + sum=value.sum.item(), + sum_squares=value.sum_squares.item(), + bucket_limits=np.array(value.bucket_limits).tolist(), + bucket_counts=np.concatenate([[0], np.array(value.bucket_counts)]).tolist(), + global_step=step, + ) + continue + + self.writer.add_scalar(k, value, global_step=step) def log_summary(self, metrics: dict[str, Any]): for k, v in metrics.items(): diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 99fd217e5..820b69983 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -14,10 +14,10 @@ class Tracker(abc.ABC): The name is borrowed from HF Accelerate. Examples: - >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker import current_tracker, log >>> from levanter.tracker.wandb import WandbTracker >>> with current_tracker(WandbTracker()): - ... log_metrics({"foo": 1}, step=0) + ... 
log({"foo": 1}, step=0) """ name: str @@ -27,7 +27,7 @@ def log_hyperparameters(self, hparams: dict[str, Any]): pass @abc.abstractmethod - def log(self, metrics: dict[str, typing.Any], *, step: Optional[int], commit: Optional[bool] = None): + def log(self, metrics: typing.Mapping[str, typing.Any], *, step: Optional[int], commit: Optional[bool] = None): """ Log metrics to the tracker. Step is always required. @@ -77,7 +77,7 @@ def log_hyperparameters(self, hparams: dict[str, Any]): for tracker in self.loggers: tracker.log_hyperparameters(hparams) - def log(self, metrics: dict[str, Any], *, step, commit=None): + def log(self, metrics: typing.Mapping[str, Any], *, step, commit=None): for tracker in self.loggers: tracker.log(metrics, step=step, commit=commit) @@ -119,7 +119,7 @@ class NoopTracker(Tracker): def log_hyperparameters(self, hparams: dict[str, Any]): pass - def log(self, metrics: dict[str, Any], *, step, commit: Optional[bool] = None): + def log(self, metrics: typing.Mapping[str, Any], *, step, commit: Optional[bool] = None): pass def log_summary(self, metrics: dict[str, Any]): diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index 1f2203df4..2b1b9b598 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -9,9 +9,11 @@ import draccus import jax +from jaxtyping import Scalar from levanter.tracker import CompositeTracker, Tracker from levanter.tracker.helpers import hparams_to_dict +from levanter.tracker.histogram import Histogram from levanter.tracker.tensorboard import TensorboardTracker from levanter.tracker.wandb import WandbTracker from levanter.utils.jax_utils import is_inside_jit @@ -19,16 +21,18 @@ logger = logging.getLogger(__name__) - _global_tracker: Optional["Tracker"] = None +LoggableValues: typing.TypeAlias = Scalar | jax.Array | str | dict | Histogram + -def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): +def log(metrics: typing.Mapping[str, LoggableValues | Any], *, step: Optional[int], commit: Optional[bool] = None): """ Log metrics to the global tracker. Args: - metrics: Metrics to log + metrics: Metrics to log. We use LoggableValues just to give you a sense of what you can log. Backends may + support additional types. step: Step to log at commit: Whether to commit the metrics. If None, uses the default for the tracker. """ @@ -40,13 +44,24 @@ def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optiona # we're inside a jit, so we need to log from the host if commit: raise ValueError("Cannot commit from inside jit") - jit_log_metrics(metrics, step=step) + jit_log(metrics, step=step) else: # TODO: do we need to coerce to np here? - _global_tracker.log(metrics, step=step) + _global_tracker.log(metrics, step=step, commit=commit) + + +# deprecated in favor of log() +def log_metrics( + metrics: typing.Mapping[str, LoggableValues | Any], *, step: Optional[int], commit: Optional[bool] = None +): + """ + Deprecated. Use log instead. 
+ """ + warnings.warn("log_metrics is deprecated in favor of log", DeprecationWarning) + log(metrics, step=step, commit=commit) -def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): +def _do_jit_log(metrics, *, step=None): try: if _global_tracker is None: warnings.warn("No global tracker set") @@ -56,9 +71,12 @@ def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], commi logger.exception("Error logging metrics") -def jit_log_metrics(metrics, *, step=None): +def jit_log(metrics, *, step=None): """uses jax effect callback to log to wandb from the host""" - jax.debug.callback(_no_throw_log_metrics, metrics, step=step) + # This doesn't work reliably on TPU, so we disable it for now + jax.debug.callback(_do_jit_log, metrics, step=step) + # global _jit_log_dict + # _jit_log_dict.update(metrics) def log_summary(metrics: dict[str, Any]): @@ -129,10 +147,10 @@ def set_global_tracker(tracker: Tracker): force: Whether to force setting the global tracker even if it is already set Examples: - >>> from levanter.tracker import set_global_tracker, log_metrics + >>> from levanter.tracker import set_global_tracker, log >>> from levanter.tracker.wandb import WandbTracker >>> set_global_tracker(WandbTracker()) - >>> log_metrics({"foo": 1}, step=0) + >>> log({"foo": 1}, step=0) """ global _global_tracker if _global_tracker is not None: @@ -166,10 +184,10 @@ def current_tracker( If a tracker is provided, returns a context manager that sets the global tracker to the provided tracker when used. Examples: - >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker import current_tracker, log >>> from levanter.tracker.wandb import WandbTracker >>> with current_tracker(WandbTracker()): - ... log_metrics({"foo": 1}, step=0) + ... log({"foo": 1}, step=0) ... current_tracker().log({"foo": 2}, step=1) """ global _global_tracker @@ -207,10 +225,10 @@ def get_tracker(name: str) -> Tracker: The tracker with the provided name Examples: - >>> from levanter.tracker import get_tracker, log_metrics + >>> from levanter.tracker import get_tracker, log >>> from levanter.tracker.wandb import WandbTracker >>> with current_tracker(WandbTracker()): - ... log_metrics({"foo": 1}, step=0) + ... log({"foo": 1}, step=0) ... 
get_tracker("wandb").log_metrics({"foo": 2}, step=1) """ tracker = current_tracker() diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 981bebf83..8e9493e61 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -7,11 +7,13 @@ from typing import Any, List, Optional, Union import jax +import numpy as np from draccus import field from git import InvalidGitRepositoryError, NoSuchPathError, Repo from levanter.tracker import Tracker from levanter.tracker.helpers import generate_pip_freeze, infer_experiment_git_root +from levanter.tracker.histogram import Histogram from levanter.tracker.tracker import TrackerConfig from levanter.utils import jax_utils @@ -48,9 +50,9 @@ def __init__(self, run: Optional[WandbRun]): self._last_warning_step = -500 def log_hyperparameters(self, hparams: dict[str, Any]): - self.run.config.update(hparams, allow_val_change=True) + self.run.config.update(_convert_values_to_loggable(hparams), allow_val_change=True) - def log(self, metrics: dict[str, Any], *, step, commit=None): + def log(self, metrics: typing.Mapping[str, Any], *, step, commit=None): if step is None and not commit: step = self.run.step @@ -64,10 +66,10 @@ def log(self, metrics: dict[str, Any], *, step, commit=None): step = int(step) - self.run.log(metrics, step=step, commit=commit) + self.run.log(_convert_values_to_loggable(metrics), step=step, commit=commit) - def log_summary(self, metrics: dict[str, Any]): - self.run.summary.update(metrics) + def log_summary(self, metrics: typing.Mapping[str, Any]): + self.run.summary.update(_convert_values_to_loggable(metrics)) def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): self.run.log_artifact(artifact_path, name=name, type=type) @@ -77,6 +79,29 @@ def finish(self): self.run.finish() +def _convert_values_to_loggable(values: typing.Mapping[str, Any]): + def convert_value_to_loggable(value: Any): + if isinstance(value, (list, tuple)): + return [convert_value_to_loggable(v) for v in value] + elif isinstance(value, typing.Mapping): + return {k: convert_value_to_loggable(v) for k, v in value.items()} + elif isinstance(value, jax.Array): + if value.ndim == 0: + return value.item() + else: + return np.array(value) + elif isinstance(value, Histogram): + import wandb + + counts, limits = value.to_numpy_histogram() + + return wandb.Histogram(np_histogram=(counts.tolist(), limits.tolist())) + else: + return value + + return convert_value_to_loggable(values) + + def is_wandb_available(): try: import wandb @@ -93,7 +118,7 @@ class WandbConfig(TrackerConfig): """ entity: Optional[str] = None # An entity is a username or team name where you send runs - project: Optional[str] = None # The name of the project where you are sending the enw run. + project: Optional[str] = "levanter" # The name of the project where you are sending the enw run. name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. id: Optional[str] = None # A unique ID for this run, used for resuming. 
It must be unique in the project diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index fb353592d..82f32422a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -9,21 +9,7 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - Generic, - Iterable, - List, - Mapping, - Optional, - Protocol, - Sequence, - Tuple, - TypeVar, - Union, -) +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Protocol, Sequence, Tuple, TypeVar, Union import equinox as eqx import jax @@ -46,6 +32,7 @@ import levanter.tracker.wandb import levanter.utils.logging from levanter import tracker +from levanter.callbacks import Callback, CBInfo, JitCallback, LambdaCallback, M, S, StepInfo from levanter.checkpoint import CheckpointerConfig, load_checkpoint_or_initialize from levanter.config import JsonAtom from levanter.data import AsyncDataset, DataLoader @@ -54,16 +41,14 @@ from levanter.tracker import TrackerConfig, capture_time from levanter.trainer_state import TrainerState, saveable_training_mask from levanter.utils import cloud_utils, fsspec_utils -from levanter.utils.jax_utils import create_fsdp_mesh +from levanter.utils.jax_utils import create_fsdp_mesh, zeros_like_tree from levanter.utils.tree_utils import inference_mode from levanter.utils.types import ComputeLossFunction, FilterSpec logger = pylogging.getLogger(__name__) -M = TypeVar("M") # Model X = TypeVar("X") # Input -S = TypeVar("S", bound=TrainerState) DEFAULT_JAX_CONFIG: Dict[str, JsonAtom] = { "jax_threefry_partitionable": True, @@ -74,42 +59,68 @@ # A note on the semantics of "step" vs "next_step": # The "step" of a TrainerState is the state after `step` steps have been taken. # A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. -@dataclass -class StepInfo(Generic[S]): - state: S - loss: float - step_duration: float - model = property(lambda self: self.state.model) - opt_state = property(lambda self: self.state.opt_state) - step = property(lambda self: int(self.state.step) - 1) - """ - The step that was just completed. If you want the next step, use `next_step`. 
- """ - next_step = property(lambda self: int(self.state.step)) +@dataclass +class _Hook: + fn: Callback + every: int @dataclass -class _Hook: - fn: Callable[[StepInfo], None] +class _JitHook: + fn: JitCallback every: int class TrainerHooks: hooks: List[_Hook] + stateful_hooks: List[_JitHook] def __init__(self): self.hooks = [] + self.stateful_hooks = [] def run_hooks(self, info: StepInfo, force: bool = False): for hook in self.hooks: if force or info.step % hook.every == 0: - hook.fn(info) + hook.fn.on_step(info, force=force) + + def run_jit_hooks_outside_step(self, info: StepInfo, cb_infos: Sequence[PyTree], force: bool = False): + for s_hook, cb_info in zip(self.stateful_hooks, cb_infos): + if force or (info.step % s_hook.every == 0): + s_hook.fn.on_step(info, cb_info) + + def run_jit_hooks(self, state: TrainerState, grad: M, force: bool = False) -> tuple[PyTree, ...]: + hook: _JitHook + hook_infos = [] + for hook in self.stateful_hooks: + hook_shape = eqx.filter_eval_shape(hook.fn.inside_step, state, grad) + new_s = jax.lax.cond( + force or (state.step % hook.every == 0), + lambda: hook.fn.inside_step(state, grad), + lambda: zeros_like_tree(hook_shape), + ) + hook_infos.append(new_s) + + return tuple(hook_infos) + + def add_hook(self, fn: Optional[Callable[[StepInfo], Any] | JitCallback | Callback] = None, *, every: int = 1): + def decorator(fn): + is_something = False + + if isinstance(fn, Callback): + self.hooks.append(_Hook(fn, every)) + is_something = True - def add_hook(self, fn: Optional[Callable[[StepInfo], Any]] = None, *, every: int = 1): - def decorator(fn: Callable[[StepInfo], None]): - self.hooks.append(_Hook(fn, every)) + if isinstance(fn, JitCallback): + self.stateful_hooks.append(_JitHook(fn, every)) + is_something = True + + if not is_something: + if not callable(fn): + raise ValueError(f"fn must be callable, got {fn}") + self.hooks.append(_Hook(LambdaCallback(fn), every)) if fn is None: return decorator @@ -214,11 +225,19 @@ def num_train_steps(self) -> int: def add_hook(self, fn: Callable[[StepInfo], Any], *, every: int = 1): ... + @typing.overload + def add_hook(self, fn: JitCallback, *, every: int = 1): + ... + + @typing.overload + def add_hook(self, fn: Callback, *, every: int = 1): + ... + @typing.overload def add_hook(self, *, every: int = 1): ... - def add_hook(self, fn: Optional[Callable[[StepInfo], Any]] = None, *, every: int = 1): + def add_hook(self, fn: Optional[Callable[[StepInfo], Any] | Callback | JitCallback] = None, *, every: int = 1): return self.hooks.add_hook(fn, every=every) def run_hooks(self, info: StepInfo, force: bool = False): @@ -365,14 +384,31 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]: """ Performs a single training step. """ + # jit hooks impose a nontrivial cost even when they're not run (since they defeat some compiler optimizations) + # so we avoid running them when they're not needed + # this results in two compiles, but the cost of the second compile is worth it + hooks_this_time = any(state.step % h.every == 0 for h in self.hooks.stateful_hooks) + with capture_time() as step_time: - loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs) - # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) + if hooks_this_time: + loss, new_state, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs) + # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) 
+ else: + loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs) loss = loss.item() # type: ignore - return StepInfo(new_state, loss, step_time()) + info = StepInfo(new_state, loss, step_time()) + + with capture_time() as hook_time: + self.run_hooks(info) + if hooks_this_time: + self.hooks.run_jit_hooks_outside_step(info, cb_states) + + levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step) + + return info - def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typing.Iterator[StepInfo[S]]: + def training_steps(self, state: S, train_loader) -> typing.Iterator[StepInfo[S]]: """ Generator that yields training steps and runs hooks. """ @@ -384,26 +420,19 @@ def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typi info = self.train_step(state, example) state = info.state - if run_hooks: - with capture_time() as hook_time: - self.run_hooks(info) - - levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=info.step) - - levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=info.step) + levanter.tracker.log({"throughput/loading_time": loading_time()}, step=info.step) yield info - def train(self, state: S, train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[S]: + def train(self, state: S, train_loader: Iterable[X]) -> StepInfo[S]: """ Performs training until the number of steps is reached. """ - for info in self.training_steps(state, train_loader, run_hooks=run_hooks): + for info in self.training_steps(state, train_loader): pass - if run_hooks: - # force hooks to run at the end - self.run_hooks(info, force=True) + # force hooks to run at the end + self.run_hooks(info, force=True) return info @@ -485,12 +514,27 @@ def _jit_train_step_fn(self): donate_args=(True,), ) - def _train_step(self, state: S, *batch, **batch_kwargs) -> tuple[Scalar, S]: + @cached_property + def _jit_train_step_fn_no_hook(self): + return named_jit( + functools.partial(self._train_step, _no_hooks=True), + axis_resources=self.parameter_axis_mapping, + out_axis_resources=self.parameter_axis_mapping, + donate_args=(True,), + ) + + def _train_step( + self, state: S, batch, batch_kwargs, _no_hooks=False + ) -> tuple[Scalar, S, Sequence[CBInfo]] | tuple[Scalar, S]: key, new_key = jax.random.split(state.training_key) model = inference_mode(state.model, False) loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key) + with hax.axis_mapping(self.parameter_axis_mapping): + if not _no_hooks: + hook_infos = self.hooks.run_jit_hooks(state, grads, force=False) + # Sophia needs to be able to access the loss function in the optimizer def obj_fun(trainable_model): model = eqx.combine(trainable_model, state.model) @@ -500,7 +544,10 @@ def obj_fun(trainable_model): new_state = state.take_step(grads, obj_fun=obj_fun) new_state = hax.shard(new_state, self.parameter_axis_mapping) - return loss, new_state + if _no_hooks: + return loss, new_state + else: + return loss, new_state, hook_infos def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]: grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 09c101b82..be77a1d99 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -1,4 +1,5 @@ import contextlib +import functools import json import warnings from dataclasses import fields @@ -13,11 
+14,13 @@ from jaxtyping import PRNGKeyArray, PyTree import haliax as hax +from haliax import is_named_array from haliax.jax_utils import is_jax_array_like -from haliax.partitioning import ResourceAxis +from haliax.partitioning import ResourceAxis, ResourceMapping X = TypeVar("X") +T = TypeVar("T", bound=PyTree) def jnp_to_python(a: jnp.ndarray): @@ -340,3 +343,29 @@ def estimated_free_device_memory(device) -> Optional[float]: in_use = stats.get("bytes_in_use", 0) return (limit - in_use) // (1024.0**3) + + +def zeros_like_tree(tree: T, axis_mapping: Optional[ResourceMapping] = None, dtype: Optional[jnp.dtype] = None) -> T: + """ + Creates a tree of zeros with the same structure as the input tree. If the input tree contains NamedArrays, then + those will be sharded according to the axis_mapping (or the context axis mapping if not provided). + """ + _zeros = functools.partial(_zeros_like, axis_mapping, dtype) + acc = jax.tree_util.tree_map(_zeros, tree, is_leaf=is_named_array) + return acc + + +def _zeros_like(mapping, dtype, n): + if isinstance(n, hax.NamedArray): + return hax.shard(hax.zeros_like(n, dtype=dtype), mapping) + elif is_jax_array_like(n): + return jnp.zeros_like(n, dtype) + else: + assert jnp.isscalar(n) + if dtype is None: + # if it's a nan, we want to go to 0 + if n != n: + return 0 + return n - n + else: + return jnp.zeros((), dtype=dtype) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 98e8e1d58..037384c51 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -16,6 +16,7 @@ import haliax as hax from haliax import Axis +from levanter.callbacks import StepInfo from levanter.checkpoint import ( Checkpointer, CheckpointInterval, @@ -25,7 +26,6 @@ load_metadata, save_checkpoint, ) -from levanter.trainer import StepInfo from levanter.trainer_state import TrainerState from test_utils import MLP, arrays_only, assert_trees_not_close diff --git a/tests/test_doremi.py b/tests/test_doremi.py index bbab04f52..3ad4aa9ab 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -143,7 +143,7 @@ def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key optimizer = optax.adam(1e-2) - trainer = Trainer(tiny_trainer_config, optimizer, compute_loss_fn) + trainer = Trainer(tiny_trainer_config, optimizer, compute_loss_fn, add_default_hooks=False) def fit_to_dataset(dataset: AsyncDataset): initial_model = init_model() @@ -154,8 +154,7 @@ def fit_to_dataset(dataset: AsyncDataset): loss = 0.0 - # state = trainer.train(state, loader, run_hooks=False) - for state in trainer.training_steps(state, loader, run_hooks=False): + for state in trainer.training_steps(state, loader): if state.step >= 200: loss += state.loss diff --git a/tests/test_histogram.py b/tests/test_histogram.py new file mode 100644 index 000000000..f2ef4fd0a --- /dev/null +++ b/tests/test_histogram.py @@ -0,0 +1,45 @@ +import jax +import numpy as np +from jax.random import PRNGKey +from jax.sharding import Mesh + +import haliax as hax +from haliax.partitioning import ResourceAxis + +import levanter.tracker.histogram +from test_utils import skip_if_not_enough_devices + + +def test_sharded_histogram_simple(): + mesh = Mesh((jax.devices()), (ResourceAxis.DATA,)) + + Batch = hax.Axis("batch", 64) + Feature = hax.Axis("feature", 128) + + with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + a = hax.random.normal(PRNGKey(1), (Batch, Feature)) + a = hax.shard(a) + hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=32) + + hist_normal, bins_normal = 
jax.numpy.histogram(a.array, bins=32) + + assert jax.numpy.allclose(hist, hist_normal) + assert jax.numpy.allclose(bins, bins_normal) + + +@skip_if_not_enough_devices(2) +def test_sharded_histogram_tp(): + mesh = Mesh(np.array(jax.devices()).reshape(-1, 2), (ResourceAxis.DATA, ResourceAxis.MODEL)) + + Batch = hax.Axis("batch", 64) + Feature = hax.Axis("feature", 128) + + with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, "feature": ResourceAxis.MODEL}): + a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100 + a = hax.shard(a) + hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=64) + + jnp_hist, jnp_bins = jax.numpy.histogram(a.array, bins=64) + + assert jax.numpy.allclose(hist, jnp_hist) + assert jax.numpy.allclose(bins, jnp_bins) diff --git a/tests/test_lora.py b/tests/test_lora.py index b6933f935..6250535b3 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -12,6 +12,7 @@ import haliax.nn as hnn from haliax.quantization import DefaultDotGeneralOp, DotGeneralOp +from levanter.callbacks import StepInfo from levanter.checkpoint import Checkpointer from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.lora import ( @@ -26,7 +27,6 @@ ) from levanter.models.attention import AttentionMask from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel -from levanter.trainer import StepInfo from levanter.trainer_state import TrainerState from levanter.utils.tree_utils import inference_mode from test_utils import skip_if_no_torch
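As a quick end-to-end check of the new histogram logging path, here is a usage sketch (not part of the patch; the metric key is made up, and `NoopTracker` is used so it runs on any backend):

```python
import jax

import levanter.tracker
from levanter.tracker import NoopTracker, current_tracker
from levanter.tracker.histogram import Histogram

# Build a Histogram from a plain jax array.
arr = jax.random.normal(jax.random.PRNGKey(0), (1024,))
hist = Histogram.from_array(arr, num_bins=32)

# Histograms log like any other metric; the wandb and tensorboard backends
# convert them via to_numpy_histogram() / add_histogram_raw().
with current_tracker(NoopTracker()):
    levanter.tracker.log({"debug/activations": hist}, step=0)

counts, limits = hist.to_numpy_histogram()
assert counts.sum() == 1024  # every sample falls in some bucket
assert len(limits) == 33  # num_bins + 1 edges
```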