Refactor TrainerState to make it a module, make save_checkpoint a nicer function (#462)

* Trackers let us abstract out TB vs wandb

* missed a few spots

* remove old config

* wip

* missed some spots

* more missed spots
dlwh authored Feb 12, 2024
1 parent 50f72d7 commit 8f9a3de
Showing 13 changed files with 330 additions and 232 deletions.
120 changes: 72 additions & 48 deletions src/levanter/checkpoint.py
@@ -7,7 +7,7 @@
import urllib.parse
from dataclasses import dataclass
from datetime import timedelta
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Callable, List, Optional, Sequence, TypeVar, Union

import equinox
import fsspec
@@ -28,8 +28,7 @@

PathLike = Union[str, pathlib.Path]

M = TypeVar("M")
S = TypeVar("S")
M = TypeVar("M", bound=PyTree)


@dataclass(frozen=True)
@@ -102,19 +101,16 @@ def __init__(

def load_checkpoint(
self,
model: M,
training_state: S,
state: M,
path: Optional[PathLike] = None,
*,
discover_latest: bool = True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[haliax.partitioning.Mesh] = None,
) -> Optional[Tuple[M, S, int]]:
) -> Optional[M]:
if path is None:
path = self.base_path
return load_checkpoint(
model, training_state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh
)
return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh)

def load_model(
self,
@@ -124,16 +120,17 @@ def load_model(
discover_latest: bool = True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[haliax.partitioning.Mesh] = None,
) -> Optional[Tuple[M, int]]:
if path is None:
path = self.base_path
ckpt = load_checkpoint(
model, None, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh
) -> Optional[M]:
"""
Convenience method/holdover from previous API for loading checkpoints.
Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint.
"""
ret_dict = self.load_checkpoint(
{"model": model}, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh
)
if ckpt is None:
if ret_dict is None:
return None
model, _, step = ckpt
return model, step
return ret_dict["model"]

def on_step(self, info, force: bool = False):
step = info.step
@@ -219,10 +216,9 @@ def _rm_checkpoint(self, checkpoint):
def save_checkpoint(self, info, destination: str):
path = os.path.join(self.base_path, destination)
logger.info(f"Saving checkpoint at step {info.step} to {path}")
model = equinox.filter(info.model, self.keep_params)
state = equinox.filter(info.state, info.state.is_trainable)
save_checkpoint(
model=model,
training_state=(info.opt_state, info.next_key),
state,
step=info.step,
checkpoint_path=path,
)
@@ -231,7 +227,7 @@ def save_checkpoint(self, info, destination: str):
logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}")


def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike):
def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike):
"""
Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint
will be saved even if a checkpoint already exists at the given path.
@@ -249,10 +245,7 @@ def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike)
fs, plain_path = _get_fs_and_plain_path(checkpoint_path)
fs.makedirs(plain_path, exist_ok=True)

tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model)
if training_state is not None:
tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "training_state"), training_state)

tree_serialize_leaves_tensorstore(checkpoint_path, tree)
save_metadata(checkpoint_path, fs, step)

logger.info(f"Saved checkpoint for step {step}")
@@ -271,22 +264,30 @@ def save_metadata(checkpoint_path, fs, step):


def load_checkpoint(
model: M,
training_state: S,
tree: M,
checkpoint_path: PathLike,
*,
subpath: Optional[str] = None,
discover_latest=True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[jax.sharding.Mesh] = None,
) -> Optional[Tuple[M, S, int]]:
) -> M:
"""
Load a checkpoint from a given path.
Returns the loaded model state, training state, and step. If discover_latest is True,
the latest checkpoint in the given path will be loaded. Otherwise, the checkpoint at
the given path will be loaded. If no checkpoint is found, returns None
Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint
in a subdirectory of the given path will be loaded. If subpath is not None, then the checkpoint
loads only that subpath of the checkpoint. This is useful for loading, e.g., just the model and not
the entire training state.
Args:
tree: an exemplar of the tree to load. Can be a PyTree[ShapeDTypeStruct] instead of a PyTree[Any]
checkpoint_path: the path to load the checkpoint from
subpath: the subpath to load from the checkpoint
discover_latest: whether to discover the latest checkpoint in the given path
axis_mapping: the axis mapping to use for loading the checkpoint
mesh: the mesh to use for loading the checkpoint
Returns:
the loaded checkpoint, with the same structure as the exemplar tree
If training_state is None, no training state will be loaded.
"""
fs: AbstractFileSystem
fs, _ = _get_fs_and_plain_path(checkpoint_path)
@@ -297,28 +298,52 @@ def load_checkpoint(
checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore

if checkpoint_path is None or not fs.exists(checkpoint_path):
return None
raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")

logger.info(f"Loading checkpoint from {checkpoint_path}")
metadata = load_metadata(checkpoint_path, fs)

model = tree_deserialize_leaves_tensorstore(
os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh
)
if subpath:
checkpoint_path = os.path.join(checkpoint_path, subpath)

if training_state is None:
training_state = None
else:
training_state = tree_deserialize_leaves_tensorstore(
os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh
)
try:
tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh)
return tree
except: # noqa
from levanter.trainer import TrainerState

return model, training_state, metadata["step"]
if not isinstance(tree, TrainerState):
raise
else:
logger.warning("Attempting to load old-style checkpoint")
model, training_state = tree.model, (tree.opt_state, tree.training_key)

model = tree_deserialize_leaves_tensorstore(
os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh
)

if training_state is None:
opt_state = None
key = None
else:
training_state = tree_deserialize_leaves_tensorstore(
os.path.join(checkpoint_path, "training_state"),
training_state,
axis_mapping=axis_mapping,
mesh=mesh,
)
opt_state, key = training_state

# TODO: pretty sure this is right, but should verify
step = metadata["step"]
new_state = dataclasses.replace(
tree, _step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore
)
return new_state


def load_metadata(checkpoint_path, fs=None):
if fs is None:
fs: AbstractFileSystem
fs, _, _ = fsspec.get_fs_token_paths(str(checkpoint_path))
with fs.open(os.path.join(checkpoint_path, "metadata.json")) as metadata_in:
metadata = json.load(metadata_in)
@@ -381,13 +406,12 @@ class CheckpointerConfig:
def expanded_path(self, run_id):
return os.path.expanduser(os.path.join(self.base_path, run_id))

def create(self, run_id, keep_params: PyTree[FilterSpec] = True) -> Checkpointer:
def create(self, run_id) -> Checkpointer:
keeps = [CheckpointInterval(**k) for k in self.keep]
return Checkpointer(
base_path=self.expanded_path(run_id),
save_interval=self.save_interval,
step_policies=keeps,
keep_params=keep_params,
)

def __post_init__(self):
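Net effect of the checkpoint.py changes: `save_checkpoint` and `load_checkpoint` now operate on a single PyTree instead of a separate model/training-state pair, and `load_checkpoint` returns one tree rather than a `(model, training_state, step)` tuple. A minimal round-trip sketch of the new API (the toy dict tree and temp directory are placeholders, not taken from the diff, and assume the tensorstore serializer accepts a plain dict tree):

```python
import tempfile

import jax.numpy as jnp

from levanter.checkpoint import load_checkpoint, save_checkpoint

# Any PyTree can stand in for a TrainerState here.
tree = {"model": {"w": jnp.ones((4, 4))}, "opt_state": {"count": jnp.zeros(())}}
path = tempfile.mkdtemp()

# New-style save: one tree, no separate `training_state` argument.
save_checkpoint(tree, step=100, checkpoint_path=path)

# New-style load: pass an exemplar with the same structure (it may also be a
# tree of jax.ShapeDtypeStruct) and get back a single tree of real arrays.
restored = load_checkpoint(tree, path, discover_latest=False)
```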
11 changes: 0 additions & 11 deletions src/levanter/data/shard_cache.py
@@ -1231,17 +1231,6 @@ def priority_fn(shard_idx, chunk_idx):

ray.get(reader_actor.add_work_group.remote(work_item))

# reader = _alternating_shard_reader.remote(
# name,
# self_ref,
# writer,
# source,
# shard_group,
# priority_fn,
# processor_actor,
# processor.batch_size,
# rows_per_chunk,
# )
self._shard_readers.append(reader_actor)

def new_chunk(self, shard_name: str, *chunks: ChunkMetadata):
6 changes: 1 addition & 5 deletions src/levanter/main/eval_lm.py
@@ -85,14 +85,10 @@ def compute_loss(model: LmHeadModel, example: LmExample):
with use_cpu_device():
model = eqx.filter_eval_shape(config.model.build, Vocab, key=key)
# TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model
ckpt = load_checkpoint(model, None, config.checkpoint_path)

assert ckpt is not None
model, _, _ = ckpt
model = load_checkpoint(model, config.checkpoint_path, subpath="model")

model = hax.shard_with_axis_mapping(model, parameter_axis_mapping)

# TODO: switch to throwing instead of returning None
loss = callbacks.eval_loss_loop(compute_loss, model, eval_loader, max_batches=total)

del model
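The eval_lm.py change shows the new idiom for restoring only the model: build an abstract, shape-only model on CPU with `filter_eval_shape`, then materialize it from the `model` subtree via `subpath="model"`. A hypothetical helper (not in the diff) sketching the same idea, including the fact that `load_checkpoint` now raises `FileNotFoundError` instead of returning `None`:

```python
import equinox as eqx

from levanter.checkpoint import load_checkpoint


def load_model_only(build_fn, checkpoint_path, *args, **kwargs):
    """Illustrative helper: build an abstract model (no parameters allocated),
    then fill it in from the `model` subtree of the checkpoint."""
    abstract_model = eqx.filter_eval_shape(build_fn, *args, **kwargs)
    try:
        return load_checkpoint(abstract_model, checkpoint_path, subpath="model")
    except FileNotFoundError:
        # load_checkpoint now raises rather than returning None; catch it here
        # if a soft failure is preferred.
        return None
```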
5 changes: 2 additions & 3 deletions src/levanter/main/export_lm_to_hf.py
@@ -51,10 +51,9 @@ def main(config: ConvertLmConfig):
model: LmHeadModel = eqx.filter_eval_shape(config.model.build, Vocab, key=key)
trainable, non_trainable = eqx.partition(model, is_inexact_arrayish)
# TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model
ckpt = load_checkpoint(trainable, None, config.checkpoint_path)
trainable = load_checkpoint(trainable, config.checkpoint_path, subpath="model")

assert ckpt is not None
trainable, _, _ = ckpt
assert trainable is not None
model = eqx.combine(trainable, non_trainable)

if config.override_vocab_size:
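The export script relies on equinox's partition/combine idiom: split the model into trainable array leaves and everything else, restore only the trainable half from the checkpoint, then reassemble. A self-contained sketch with a toy MLP (using `eqx.is_inexact_array` in place of levanter's `is_inexact_arrayish`, and skipping the actual checkpoint load):

```python
import equinox as eqx
import jax

# Toy stand-in for the LM head model.
model = eqx.nn.MLP(in_size=4, out_size=4, width_size=8, depth=1, key=jax.random.PRNGKey(0))

# Split into trainable (inexact array) leaves and static/non-trainable parts.
trainable, non_trainable = eqx.partition(model, eqx.is_inexact_array)

# In export_lm_to_hf.py, `trainable` is what load_checkpoint(..., subpath="model")
# replaces with the checkpointed weights; here we just recombine the two halves.
model = eqx.combine(trainable, non_trainable)
```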
7 changes: 5 additions & 2 deletions src/levanter/main/lora_lm.py
@@ -93,7 +93,10 @@ def compute_loss(model, example: LmExample, key=None):
state = trainer.initial_state(training_key, model=model)

all_param_count = parameter_count(state.model)
just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
# TODO: remove this once we put this in trainer itself
just_lora_params = parameter_count(
levanter.trainer._partition_trainable_params(state.model, lora_param_filter)
)

levanter.tracker.log_summary(
{
@@ -140,7 +143,7 @@ def compute_loss(model, example: LmExample, key=None):
# TODO: implement iter_data.seek(resume_step +1)
import tqdm

for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"):
for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"):
next(iter_data)

## OK, actually run training!
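The LoRA parameter count above now counts only the trainable subset of the model. Conceptually that is just "filter the model tree by the LoRA filter and sum the leaf sizes"; a generic sketch using equinox filtering on a toy tree rather than the private `_partition_trainable_params` helper:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


def count_params(tree) -> int:
    """Sum the sizes of all inexact-array leaves in a PyTree."""
    leaves = jax.tree_util.tree_leaves(eqx.filter(tree, eqx.is_inexact_array))
    return sum(leaf.size for leaf in leaves)


# Toy "model": pretend only the lora_* leaves are trainable, mirroring what
# lora_param_filter selects in lora_lm.py.
model = {"attn": {"w": jnp.ones((64, 64)), "lora_a": jnp.ones((64, 4)), "lora_b": jnp.ones((4, 64))}}
lora_filter = {"attn": {"w": False, "lora_a": True, "lora_b": True}}

all_param_count = count_params(model)                             # 4096 + 256 + 256
just_lora_params = count_params(eqx.filter(model, lora_filter))   # 256 + 256
```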
2 changes: 1 addition & 1 deletion src/levanter/main/train_lm.py
@@ -182,7 +182,7 @@ def compute_log_probs(model, example: LmExample):
# TODO: implement iter_data.seek(resume_step +1)
import tqdm

for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"):
for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"):
next(train_loader)

## OK, actually run training!
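The resume loop here (and in lora_lm.py above) drops the `+ 1` because `load_checkpoint` now returns a state whose step is already the saved step plus one (checkpoint.py marks that `+ 1` with a TODO, so treat this as the intended, not fully verified, semantics). With `state.step` being the next step to run, the loader skips exactly `state.step` batches. A toy illustration assuming 0-indexed steps and one batch per step:

```python
# The checkpoint recorded step 99, i.e. steps 0..99 consumed batches 0..99.
saved_step = 99
resumed_step = saved_step + 1   # what state.step reports after load_checkpoint

batches = iter(range(1_000))    # stand-in for the training data loader
for _ in range(resumed_step):   # skip the batches already consumed
    next(batches)

assert next(batches) == resumed_step  # training resumes on batch 100
```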
5 changes: 2 additions & 3 deletions src/levanter/main/viz_logprobs.py
@@ -81,10 +81,9 @@ def compute_log_probs(model: LmHeadModel, example: LmExample):
with use_cpu_device():
model = eqx.filter_eval_shape(config.model.build, Vocab, key=key)
# TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model
ckpt = load_checkpoint(model, None, config.checkpoint_path)
model = load_checkpoint(model, config.checkpoint_path, subpath="model")

assert ckpt is not None
model, _, _ = ckpt
assert model is not None

model = hax.shard_with_axis_mapping(model, parameter_axis_mapping)

(Diffs for the remaining changed files were not loaded on this page.)