diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 15f16a203..5c243946e 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -7,7 +7,7 @@ import urllib.parse from dataclasses import dataclass from datetime import timedelta -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Callable, List, Optional, Sequence, TypeVar, Union import equinox import fsspec @@ -28,8 +28,7 @@ PathLike = Union[str, pathlib.Path] -M = TypeVar("M") -S = TypeVar("S") +M = TypeVar("M", bound=PyTree) @dataclass(frozen=True) @@ -102,19 +101,16 @@ def __init__( def load_checkpoint( self, - model: M, - training_state: S, + state: M, path: Optional[PathLike] = None, *, discover_latest: bool = True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, S, int]]: + ) -> Optional[M]: if path is None: path = self.base_path - return load_checkpoint( - model, training_state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) + return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh) def load_model( self, @@ -124,16 +120,17 @@ def load_model( discover_latest: bool = True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, int]]: - if path is None: - path = self.base_path - ckpt = load_checkpoint( - model, None, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh + ) -> Optional[M]: + """ + Convenience method/holdover from previous API for loading checkpoints. + Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. + """ + ret_dict = self.load_checkpoint( + {"model": model}, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh ) - if ckpt is None: + if ret_dict is None: return None - model, _, step = ckpt - return model, step + return ret_dict["model"] def on_step(self, info, force: bool = False): step = info.step @@ -219,10 +216,9 @@ def _rm_checkpoint(self, checkpoint): def save_checkpoint(self, info, destination: str): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") - model = equinox.filter(info.model, self.keep_params) + state = equinox.filter(info.state, info.state.is_trainable) save_checkpoint( - model=model, - training_state=(info.opt_state, info.next_key), + state, step=info.step, checkpoint_path=path, ) @@ -231,7 +227,7 @@ def save_checkpoint(self, info, destination: str): logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}") -def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike): +def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike): """ Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint will be saved even if a checkpoint already exists at the given path. 
@@ -249,10 +245,7 @@ def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike) fs, plain_path = _get_fs_and_plain_path(checkpoint_path) fs.makedirs(plain_path, exist_ok=True) - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model) - if training_state is not None: - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "training_state"), training_state) - + tree_serialize_leaves_tensorstore(checkpoint_path, tree) save_metadata(checkpoint_path, fs, step) logger.info(f"Saved checkpoint for step {step}") @@ -271,22 +264,30 @@ def save_metadata(checkpoint_path, fs, step): def load_checkpoint( - model: M, - training_state: S, + tree: M, checkpoint_path: PathLike, *, + subpath: Optional[str] = None, discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[Tuple[M, S, int]]: +) -> M: """ - Load a checkpoint from a given path. - - Returns the loaded model state, training state, and step. If discover_latest is True, - the latest checkpoint in the given path will be loaded. Otherwise, the checkpoint at - the given path will be loaded. If no checkpoint is found, returns None + Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint + in a subdirectory of the given path will be loaded. If subpath is not None, then the checkpoint + loads only that subpath of the checkpoint. This is useful for loading, e.g., just the model and not + the entire training state. + + Args: + tree: an exemplar of the tree to load. Can be a PyTree[ShapeDTypeStruct] instead of a PyTree[Any] + checkpoint_path: the path to load the checkpoint from + subpath: the subpath to load from the checkpoint + discover_latest: whether to discover the latest checkpoint in the given path + axis_mapping: the axis mapping to use for loading the checkpoint + mesh: the mesh to use for loading the checkpoint + Returns: + the loaded checkpoint, with the same structure as the exemplar tree - If training_state is None, no training state will be loaded. 
""" fs: AbstractFileSystem fs, _ = _get_fs_and_plain_path(checkpoint_path) @@ -297,28 +298,52 @@ def load_checkpoint( checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore if checkpoint_path is None or not fs.exists(checkpoint_path): - return None + raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) - model = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh - ) + if subpath: + checkpoint_path = os.path.join(checkpoint_path, subpath) - if training_state is None: - training_state = None - else: - training_state = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh - ) + try: + tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) + return tree + except: # noqa + from levanter.trainer import TrainerState - return model, training_state, metadata["step"] + if not isinstance(tree, TrainerState): + raise + else: + logger.warning("Attempting to load old-style checkpoint") + model, training_state = tree.model, (tree.opt_state, tree.training_key) + + model = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh + ) + + if training_state is None: + opt_state = None + key = None + else: + training_state = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "training_state"), + training_state, + axis_mapping=axis_mapping, + mesh=mesh, + ) + opt_state, key = training_state + + # TODO: pretty sure this is right, but should verify + step = metadata["step"] + new_state = dataclasses.replace( + tree, _step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore + ) + return new_state def load_metadata(checkpoint_path, fs=None): if fs is None: - fs: AbstractFileSystem fs, _, _ = fsspec.get_fs_token_paths(str(checkpoint_path)) with fs.open(os.path.join(checkpoint_path, "metadata.json")) as metadata_in: metadata = json.load(metadata_in) @@ -381,13 +406,12 @@ class CheckpointerConfig: def expanded_path(self, run_id): return os.path.expanduser(os.path.join(self.base_path, run_id)) - def create(self, run_id, keep_params: PyTree[FilterSpec] = True) -> Checkpointer: + def create(self, run_id) -> Checkpointer: keeps = [CheckpointInterval(**k) for k in self.keep] return Checkpointer( base_path=self.expanded_path(run_id), save_interval=self.save_interval, step_policies=keeps, - keep_params=keep_params, ) def __post_init__(self): diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f5faa9b36..c1d24c1a0 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1231,17 +1231,6 @@ def priority_fn(shard_idx, chunk_idx): ray.get(reader_actor.add_work_group.remote(work_item)) - # reader = _alternating_shard_reader.remote( - # name, - # self_ref, - # writer, - # source, - # shard_group, - # priority_fn, - # processor_actor, - # processor.batch_size, - # rows_per_chunk, - # ) self._shard_readers.append(reader_actor) def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 9b056b950..c7976ad41 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -85,14 +85,10 @@ def compute_loss(model: 
LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) - - assert ckpt is not None - model, _, _ = ckpt + model = load_checkpoint(model, config.checkpoint_path, subpath="model") model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) - # TODO: switch to throwing instead of returning None loss = callbacks.eval_loss_loop(compute_loss, model, eval_loader, max_batches=total) del model diff --git a/src/levanter/main/export_lm_to_hf.py b/src/levanter/main/export_lm_to_hf.py index 50a8e4b92..7fd4d073d 100644 --- a/src/levanter/main/export_lm_to_hf.py +++ b/src/levanter/main/export_lm_to_hf.py @@ -51,10 +51,9 @@ def main(config: ConvertLmConfig): model: LmHeadModel = eqx.filter_eval_shape(config.model.build, Vocab, key=key) trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(trainable, None, config.checkpoint_path) + trainable = load_checkpoint(trainable, config.checkpoint_path, subpath="model") - assert ckpt is not None - trainable, _, _ = ckpt + assert trainable is not None model = eqx.combine(trainable, non_trainable) if config.override_vocab_size: diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 5120c9e22..6b845b516 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -93,7 +93,10 @@ def compute_loss(model, example: LmExample, key=None): state = trainer.initial_state(training_key, model=model) all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + # TODO: remove this once we put this in trainer itself + just_lora_params = parameter_count( + levanter.trainer._partition_trainable_params(state.model, lora_param_filter) + ) levanter.tracker.log_summary( { @@ -140,7 +143,7 @@ def compute_loss(model, example: LmExample, key=None): # TODO: implement iter_data.seek(resume_step +1) import tqdm - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): next(iter_data) ## OK, actually run training! diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 2dbd705d5..68f63b987 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -182,7 +182,7 @@ def compute_log_probs(model, example: LmExample): # TODO: implement iter_data.seek(resume_step +1) import tqdm - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): next(train_loader) ## OK, actually run training! 
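Note: the callers above (eval_lm, export_lm_to_hf, lora_lm, train_lm) all move to the single-tree checkpoint API introduced in checkpoint.py. A minimal sketch of that API follows; the dict layout, shapes, and temporary paths are illustrative and not part of this patch, only `save_checkpoint`/`load_checkpoint` and the `subpath="model"` convention come from the diff.

```python
# Minimal sketch (not part of the patch) of the single-tree checkpoint API.
# The dict layout, shapes, and temp paths are illustrative.
import tempfile

import jax.numpy as jnp

from levanter.checkpoint import load_checkpoint, save_checkpoint

model = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}

with tempfile.TemporaryDirectory() as tmpdir:
    # save_checkpoint now takes a single PyTree; nesting the model under a
    # "model" key mirrors the layout the Checkpointer writes for a TrainerState.
    save_checkpoint({"model": model}, step=0, checkpoint_path=f"{tmpdir}/step-0")

    # load_checkpoint takes an exemplar tree and returns the same structure.
    # subpath="model" reads just that subtree, which is what eval_lm,
    # viz_logprobs, and export_lm_to_hf now do instead of unpacking a
    # (model, training_state, step) tuple.
    exemplar = {"w": jnp.zeros((4, 4)), "b": jnp.zeros(4)}
    restored = load_checkpoint(exemplar, f"{tmpdir}/step-0", subpath="model", discover_latest=False)
```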
diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index b992cd3f5..ef16a7238 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -81,10 +81,9 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) + model = load_checkpoint(model, config.checkpoint_path, subpath="model") - assert ckpt is not None - model, _, _ = ckpt + assert model is not None model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 7d8661c91..41f5d04ab 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -1,5 +1,6 @@ import atexit import copy +import dataclasses import functools import logging as pylogging import os @@ -30,7 +31,6 @@ import jmp import numpy as np from draccus import field -from jax import ShapeDtypeStruct from jax.experimental import multihost_utils from jax.sharding import Mesh from jaxtyping import PRNGKeyArray, PyTree @@ -39,13 +39,13 @@ import haliax as hax from haliax import Axis from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit -from haliax.types import Scalar +from haliax.types import IntScalar, Scalar import levanter.logging import levanter.tracker import levanter.tracker.wandb from levanter import tracker -from levanter.checkpoint import CheckpointerConfig +from levanter.checkpoint import CheckpointerConfig, load_checkpoint from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig @@ -54,7 +54,7 @@ from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils -from levanter.utils.jax_utils import is_inexact_arrayish +from levanter.utils.jax_utils import as_arrayish, is_inexact_arrayish from levanter.utils.tree_utils import inference_mode @@ -62,7 +62,6 @@ X = TypeVar("X") # Input M = TypeVar("M", bound=PyTree) -S = TypeVar("S", bound=PyTree) DEFAULT_JAX_CONFIG = { "jax_threefry_partitionable": True, @@ -74,14 +73,36 @@ # A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. -@dataclass -class TrainerState(Generic[M]): - step: int +class TrainerState(eqx.Module, Generic[M]): + """ + This is the state of the trainer. It contains the model, optimizer state, and random key. + It is an equinox Module because it is a PyTree that gets passed to the core `train_step` method + of the Trainer. This unfortunately means that `step` is an Array and not an int, hence the IntScalar. + + It's designed to be extended by subclasses. + """ + + _step: IntScalar = eqx.field(converter=lambda x: as_arrayish(x)) model: M opt_state: OptState training_key: PRNGKeyArray + is_trainable: PyTree[FilterSpec] # = eqx.field(static=True) + + @cached_property + def step(self) -> int: + return int(self._step) + + @property + def trainable_model(self) -> M: + return eqx.filter(self.model, self.is_trainable) +S = TypeVar("S", bound=TrainerState) + + +# A note on the semantics of "step" vs "next_step": # The "step" of a TrainerState is the state after `step` steps have been taken. # A "StepInfo"'s step is the step that was just completed.
If you want the next step, use `next_step`. @dataclass class StepInfo(Generic[M]): state: TrainerState[M] @@ -90,7 +111,6 @@ class StepInfo(Generic[M]): model = property(lambda self: self.state.model) opt_state = property(lambda self: self.state.opt_state) - next_key = property(lambda self: self.state.training_key) step = property(lambda self: self.state.step - 1) """ @@ -172,12 +192,12 @@ def loss_fn(self): Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute """ - @named_jit(in_axis_resources=self.parameter_axis_mapping, axis_resources=self.compute_axis_mapping) + @named_jit(axis_resources=self.compute_axis_mapping) @functools.wraps(self._raw_loss_function) def fn(model, *batch, **batch_kwargs): with hax.axis_mapping(self.compute_axis_mapping): model = self.mp.cast_to_compute(model) - return self._raw_loss_function(model, *batch, **batch_kwargs) + return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs)) return fn @@ -273,79 +293,81 @@ def initial_state( raise ValueError("one of model and model_init must be specified") if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials - # We can't use plain partials because they aren't pytrees + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) + del model assert model_init is not None - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) + # first try to load a full trainer state checkpoint + checkpoint_path = self.config.load_checkpoint_path + if checkpoint_path is None: + checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) + + do_load_checkpoint = self.config.load_checkpoint + axis_mapping = self.parameter_axis_mapping + mesh = self.device_mesh + initial_model_path = self.config.initialize_from - # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = self.trainable_params_only(model_shape) + # we don't save the full trainer state, so we need to filter out the non-trainable parameters - ckpt = self.maybe_load_checkpoint( - trainable_model_shape, - (opt_state_shape, training_key), - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, + def init_state_and_model(model_init, training_key, is_trainable): + model = model_init() + state = self._initialize_state_from_scratch(model, training_key, is_trainable) + return state + + trainer_state_shape = eqx.filter_eval_shape( + init_state_and_model, model_init, training_key, self.is_trainable_param ) + saveable_state_shape = _make_saveable_trainer_state(trainer_state_shape, self.is_trainable_param) - if ckpt is not None: - trainable_model, (opt_state, training_key), completed_step = ckpt - if model is not None: - model = eqx.combine(trainable_model, model) - else: - model = eqx.combine(trainable_model, model_shape) - - if any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(model)): - # if we're resuming, we need to re-initialize the non-trainable parameters to their original values - non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) - model = eqx.combine(trainable_model, non_trainable) - - step = completed_step + 1 - elif self.config.initialize_from is not None: - # initialize from a levanter checkpoint - logger.info(f"Initializing model from checkpoint 
{self.config.initialize_from}") - match levanter.checkpoint.load_checkpoint( - model_shape, - None, - self.config.initialize_from, - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ): - # new_model is probably only the trainable parameters, so we init the rest - case base_model, _, loaded_step: - logger.info(f"Initialized from step {loaded_step} of {self.config.initialize_from}") - old_model_init = model_init - - model_init = jax.tree_util.Partial(lambda m: eqx.combine(m, old_model_init()), base_model) - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)( - model_init - ) - - step = 0 - case None: - raise ValueError(f"Could not load model from checkpoint {self.config.initialize_from}") - else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) - step = 0 + if do_load_checkpoint is not False: + try: + state = load_checkpoint(saveable_state_shape, checkpoint_path, axis_mapping=axis_mapping, mesh=mesh) + except FileNotFoundError: + if do_load_checkpoint: + raise + else: + state = None + + # if that fails, try to load just a model from a checkpoint for initialization + if state is None and initial_model_path is not None: + logger.info(f"Initializing from {initial_model_path}") + # todo: we are potentially holding two models in memory at once here, if we pass in a model + # instead of a model_init and we use initialize_from. We could avoid this by deleting + # any to-be-loaded parameters from the model before loading, but that's a bit more complicated + loaded_model = load_checkpoint( + saveable_state_shape.model, + initial_model_path, + axis_mapping=axis_mapping, + mesh=mesh, + subpath="model", + ) + + # we don't necessarily load the full model, so we need to combine it with the model init + model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), loaded_model, model_init) - return TrainerState(step, model, opt_state, training_key) + # now we initialize a fresh trainer state, possibly just to finish any missing fields + @named_jit(axis_resources=axis_mapping, donate_args=(True, True, True, False)) + def init_state(partial_state, model_init, training_key, is_trainable): + model = model_init() + fresh_state = self._initialize_state_from_scratch(model, training_key, is_trainable) + return eqx.combine(partial_state, fresh_state) + + state = init_state(state, model_init, training_key, self.is_trainable_param) + + return state def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ Performs a single training step. """ with capture_time() as step_time: - key, new_key = jax.random.split(state.training_key) - loss, new_model, new_optstate = self._train_step_fn( - state.model, state.opt_state, *batch, **batch_kwargs, key=key - ) + loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) 
loss = loss.item() # type: ignore - return StepInfo(TrainerState(state.step + 1, new_model, new_optstate, new_key), loss, step_time()) + return StepInfo(new_state, loss, step_time()) def training_steps( self, state: TrainerState[M], train_loader, run_hooks: bool = True @@ -355,7 +377,7 @@ def training_steps( """ iter_data = iter(train_loader) - while state.step < self.config.num_train_steps: + while state.step < self.num_train_steps: with capture_time() as loading_time: example = next(iter_data) @@ -391,7 +413,7 @@ def _add_default_hooks(self): self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) self.add_hook(callbacks.log_step_info, every=1) # engine.add_hook(callbacks.log_memory_usage(), every=1) - checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) + checkpointer = self.config.checkpointer.create(self.run_id) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency def add_eval_hook(self, eval_dataset, name: Optional[str] = None): @@ -440,34 +462,19 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar return ShardedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) @cached_property - def _train_step_fn(self): - @named_jit( - axis_resources=self.parameter_axis_mapping, - out_axis_resources=self.parameter_axis_mapping, - donate_args=(True, True), - ) - def train_step(model, opt_state, *batch, **batch_kwargs): - model = inference_mode(model, False) + def _jit_train_step_fn(self): + return named_jit(self._train_step, axis_resources=self.parameter_axis_mapping, donate_args=(True,)) - # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = self.partition_trainable_params(model) + def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: + key, new_key = jax.random.split(state.training_key) + model = inference_mode(state.model, False) - def split_loss_fn(trainable_model, *batch, **batch_kwargs): - model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs) + loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key) - loss, grads = self._compute_gradients_microbatched(split_loss_fn, trainable_model, batch, **batch_kwargs) + new_state = self._take_train_step(state, model, grads, *batch, **batch_kwargs, key=key) + new_state = dataclasses.replace(new_state, training_key=new_key) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) - - partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) - - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) - model = eqx.apply_updates(model, updates) - - return loss, model, opt_state - - return train_step + return loss, new_state def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]: grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) @@ -480,22 +487,59 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwar ) return grad_fn(model, *batch, **batch_kwargs) - def _init_model_and_opt_state(self, model_init): - model = model_init() - # only force trainable params to param precision. 
Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) - trainable = self.mp.cast_to_param(trainable) - non_trainable = self.mp.cast_to_compute(non_trainable) - model = eqx.combine(trainable, non_trainable) - opt_state = self.optimizer.init(trainable) - return model, opt_state - - def _init_non_trainable_params(self, model_init): - model = model_init() + def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: + """ + Takes a training step. This is a separate method so that it can be overridden or used in a subclass. + """ + # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us + with hax.axis_mapping(self.parameter_axis_mapping): + train_grads = _partition_trainable_params(grads, state.is_trainable)[0] + trainable_model = _partition_trainable_params(model, state.is_trainable)[0] + updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) + + partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) + + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) + model = eqx.apply_updates(model, updates) + + return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state) + + def _initialize_state_from_scratch(self, model, training_key, is_trainable): # only force trainable params to param precision. Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) - non_trainable = self.mp.cast_to_compute(non_trainable) - return non_trainable + model = cast_params_by_trainability(model, self.mp, is_trainable) + opt_state = init_optimizer_for_trainables(self.optimizer, model, is_trainable) + + return TrainerState(0, model, opt_state, training_key, is_trainable) + + +def init_optimizer_for_trainables(optimizer, model, is_trainable): + trainable, _ = _partition_trainable_params(model, is_trainable) + opt_state = optimizer.init(trainable) + return opt_state + + +def _make_saveable_trainer_state(trainer_state: S, is_trainable) -> S: + """ + Returns the shape of the trainer state that we save to a checkpoint. This is used to load a checkpoint. + You can override if you really need custom checkpointing logic. By default everything in the trainer state + is saved (except for non-trainable model parameters) + """ + saveable_model = eqx.filter(trainer_state.model, is_trainable) + saveable_state = dataclasses.replace(trainer_state, model=saveable_model) + return saveable_state + + +def cast_params_by_trainability(model, mp, is_trainable): + """ + Casts the parameters of a model to the appropriate precision based on the is_trainable filter spec. + Trainable parameters are cast to param precision, non-trainable parameters are cast to compute precision. + """ + + trainable, non_trainable = _partition_trainable_params(model, is_trainable) + trainable = mp.cast_to_param(trainable) + non_trainable = mp.cast_to_compute(non_trainable) + model = eqx.combine(trainable, non_trainable) + return model def trainable_params_only(self, model: M) -> M: """ @@ -784,3 +828,32 @@ def initialize(config: TrainerConfig | AllConfig): def _params_only(t): return eqx.filter(t, is_inexact_arrayish) + + +def _partition_trainable_params(model, filter): + """ + Partitions the model into trainable and non-trainable parameters. 
This is used internally + for the gradient calculation and checkpointing, but you can also use it to filter out params for logging + or something. + + Returns: + trainable, non-trainable + """ + + def trainable_and_diffable(pred): + if callable(pred): + return lambda x: pred(x) and is_inexact_arrayish(x) + elif pred is True: + return is_inexact_arrayish + else: + return pred + + combined_mask = jax.tree_util.tree_map(trainable_and_diffable, filter) + return eqx.partition(model, combined_mask) + + +def _ensure_scalar(x: hax.types.Scalar | hax.NamedArray) -> hax.types.Scalar: + if isinstance(x, hax.NamedArray): + return x.scalar() + else: + return x diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index c22525fd6..db54b2569 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import pathlib import tempfile @@ -26,10 +27,11 @@ def _dummy_step_info(step): return StepInfo( state=TrainerState( # + 1 b/c step here is next step - step=step + 1, + _step=step + 1, model=None, opt_state=(), training_key=(), + is_trainable=True, ), loss=0.0, step_duration=0.0, @@ -139,42 +141,41 @@ def advance_time(delta_seconds): assert _get_checkpoint_steps(tmpdir) == [2, 4, 6, 8, 10, 15, 20, 30, 40, 49] # 49 is last temporary checkpoint +def _make_state(step, key): + model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) + optim = optax.adam(1e-4) + opt_state = optim.init(arrays_only(model)) + + return TrainerState(step, model, opt_state, key, True) + + def test_checkpoint_simple(): key0 = jax.random.PRNGKey(0) key1 = jax.random.PRNGKey(1) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - optim = optax.adam(1e-4) - opt_state = optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) - rep_model, rep_state, rep_key = make_state(key1) + initial_state = _make_state(10, key0) + rep_state = _make_state(2, key1) - assert_trees_not_close(initial_model, rep_model) + assert_trees_not_close(initial_state.model, rep_state.model) with tempfile.TemporaryDirectory() as tmpdir: save_checkpoint( - initial_model, - (initial_opt_state, initial_key), - step=10, + initial_state, + step=initial_state.step, checkpoint_path=tmpdir, ) - restored_model, (restored_optstate, rkey), step = load_checkpoint( - rep_model, - (rep_state, rep_key), + restored_state = load_checkpoint( + rep_state, checkpoint_path=tmpdir, discover_latest=False, ) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(initial_model)), + jax.tree_util.tree_leaves(arrays_only(restored_state.model)), + jax.tree_util.tree_leaves(arrays_only(initial_state.model)), ) - assert all(np.isclose(rkey, initial_key)) - assert step == 10 + assert all(np.isclose(restored_state.training_key, initial_state.training_key)) + assert restored_state.step == initial_state.step def test_checkpoint_steps(): @@ -183,13 +184,7 @@ def test_checkpoint_steps(): optim = optax.adam(1e-4) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - opt_state = optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) + initial_state = _make_state(10, key0) data = jax.random.uniform(key0, (2, 2)) @eqx.filter_grad @@ -197,41 +192,33 @@ def loss_fn(model, data): m = jax.vmap(model) return jnp.mean(jnp.square(m(data))) - 
model, state = initial_model, initial_opt_state + state = initial_state for i in range(3): - grad = loss_fn(model, data) - updates, state = optim.update(grad, state) - model = eqx.apply_updates(model, updates) + grad = loss_fn(state.model, data) + updates, new_state = optim.update(grad, state.opt_state) + model = eqx.apply_updates(state.model, updates) + state = dataclasses.replace(state, _step=state.step + 1, model=model, opt_state=new_state) - assert_trees_not_close(model, initial_model) - assert_trees_not_close(state, initial_opt_state) + assert_trees_not_close(state, initial_state) - rep_model, rep_state, rep_key = make_state(key1) - assert_trees_not_close(model, rep_model) + rep_state = _make_state(42, key1) assert_trees_not_close(state, rep_state) with tempfile.TemporaryDirectory() as tmpdir: - save_checkpoint(model, state, step=3, checkpoint_path=tmpdir) - restored_model, restored_optstate, step = load_checkpoint( - rep_model, rep_state, checkpoint_path=tmpdir, discover_latest=False - ) + save_checkpoint(state, step=3, checkpoint_path=tmpdir) + restored_state = load_checkpoint(rep_state, checkpoint_path=tmpdir, discover_latest=False) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(model)), - ) - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_optstate)), + jax.tree_util.tree_leaves(arrays_only(restored_state)), jax.tree_util.tree_leaves(arrays_only(state)), ) - assert step == 3 def test_checkpoint_discovery(): with tempfile.TemporaryDirectory() as tempdir: - save_checkpoint(model=1, training_state=2, step=10, checkpoint_path=f"{tempdir}/step-10") - save_checkpoint(model=3, training_state=4, step=20, checkpoint_path=f"{tempdir}/step-20") - save_checkpoint(model=5, training_state=6, step=30, checkpoint_path=f"{tempdir}/step-30") + save_checkpoint(dict(model=1, training_state=2), step=10, checkpoint_path=f"{tempdir}/step-10") + save_checkpoint(dict(model=3, training_state=4), step=20, checkpoint_path=f"{tempdir}/step-20") + save_checkpoint(dict(model=5, training_state=6), step=30, checkpoint_path=f"{tempdir}/step-30") latest = discover_latest_checkpoint(tempdir) assert latest == f"{tempdir}/step-30" diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index 178069f26..a6bf3c8d9 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -13,6 +13,7 @@ from levanter.distributed import RayConfig from levanter.models.gpt2 import Gpt2LMHeadModel from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerState from levanter.utils.py_utils import logical_cpu_core_count @@ -43,7 +44,9 @@ def test_eval_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + state = TrainerState(0, model, model, jax.random.PRNGKey(0), True) + + save_checkpoint(state, 0, f"{f}/ckpt") config = eval_lm.EvalLmConfig( data=data_config, diff --git a/tests/test_export_to_hf.py b/tests/test_export_to_hf.py index 3ce092789..ed6a0d4c0 100644 --- a/tests/test_export_to_hf.py +++ b/tests/test_export_to_hf.py @@ -34,7 +34,7 @@ def test_export_lm_to_hf(): # in our trainer, we only export the trainable params trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) - save_checkpoint(trainable, None, 0, f"{tmpdir}/ckpt") + save_checkpoint({"model": trainable}, 0, f"{tmpdir}/ckpt") try: config = export_lm_to_hf.ConvertLmConfig( diff --git 
a/tests/test_tensorstore_serialization.py b/tests/test_tensorstore_serialization.py index 398bad1f0..caf8a365d 100644 --- a/tests/test_tensorstore_serialization.py +++ b/tests/test_tensorstore_serialization.py @@ -1,10 +1,12 @@ from tempfile import TemporaryDirectory +from typing import Any import equinox as eqx import jax import jax.numpy as jnp import numpy as np import optax +import pytest from chex import assert_trees_all_close import haliax as hax @@ -127,3 +129,26 @@ def make_state(key): jax.tree_util.tree_leaves(arrays_only(restored_model)), jax.tree_util.tree_leaves(arrays_only(initial_model)), ) + + +def test_tensorstore_ok_with_nones(): + A = hax.Axis("A", 10) + + class MyModule(eqx.Module): + a: Any + b: Any + + m = MyModule(a=None, b=hax.zeros(A)) + m2 = MyModule(a=None, b=hax.ones(A)) + + with TemporaryDirectory() as tmpdir: + tree_serialize_leaves_tensorstore(tmpdir, m) + m3 = tree_deserialize_leaves_tensorstore(tmpdir, m2) + assert m3.a is None + assert hax.all(m3.b == hax.zeros(A)) + + m3 = MyModule(a=hax.zeros(A), b=hax.ones(A)) + with TemporaryDirectory() as tmpdir: + tree_serialize_leaves_tensorstore(tmpdir, m2) + with pytest.raises(ValueError): + tree_deserialize_leaves_tensorstore(tmpdir, m3) diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 29d8f943c..71d117055 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -47,7 +47,7 @@ def test_viz_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + save_checkpoint({"model": model}, 0, f"{f}/ckpt") config = viz_logprobs.VizGpt2Config( data=data_config,
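Usage note: with TrainerState now an equinox Module, a whole-state round trip looks roughly like the sketch below, adapted from the updated tests/test_checkpoint.py. The `eqx.nn.MLP`/`optax.adam` setup and `eqx.is_inexact_array` filter are illustrative stand-ins for the test helpers (`MLP`, `arrays_only`), not part of this patch.

```python
# Rough sketch (not part of the patch): save and restore a full TrainerState,
# adapted from the updated tests/test_checkpoint.py. eqx.nn.MLP and
# eqx.is_inexact_array stand in for the test helpers (MLP, arrays_only).
import tempfile

import equinox as eqx
import jax
import optax

from levanter.checkpoint import load_checkpoint, save_checkpoint
from levanter.trainer import TrainerState

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key)
opt_state = optax.adam(1e-4).init(eqx.filter(model, eqx.is_inexact_array))

# Field order: _step, model, opt_state, training_key, is_trainable.
state = TrainerState(0, model, opt_state, key, True)

with tempfile.TemporaryDirectory() as tmpdir:
    save_checkpoint(state, step=state.step, checkpoint_path=tmpdir)

    # The exemplar only supplies the tree structure; its leaves are replaced
    # by the deserialized values, and .step is derived from the stored _step.
    restored = load_checkpoint(state, tmpdir, discover_latest=False)
    assert restored.step == state.step
```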