Commit beed4de: Use JAX's serialization manager (#829)

dlwh authored Dec 3, 2024
1 parent 578fb5b

Showing 6 changed files with 202 additions and 210 deletions.

src/levanter/checkpoint.py (2 changes: 1 addition & 1 deletion)

@@ -157,7 +157,7 @@ def on_step(self, info, force: bool = False):
             if not force:
                 return  # don't save checkpoint at step 0 unless forced
 
-        if step == self._last_save_step:
+        if step == self._last_save_step and not force:
             # we've already saved a checkpoint at this step
             return
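
The one-line change above makes force bypass the same-step dedup guard, so a forced save goes through even when a checkpoint was already written at that step. A minimal sketch of the new logic (an illustrative class, not Levanter's actual checkpointer):

class _SaverSketch:
    def __init__(self):
        self._last_save_step = None

    def on_step(self, step, force=False):
        if step == self._last_save_step and not force:
            return "skipped"  # we've already saved a checkpoint at this step
        self._last_save_step = step
        return "saved"

s = _SaverSketch()
assert s.on_step(10) == "saved"
assert s.on_step(10) == "skipped"            # duplicate step, not forced
assert s.on_step(10, force=True) == "saved"  # force now bypasses the dedup guard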

src/levanter/tensorstore_serialization.py (190 changes: 88 additions & 102 deletions)

@@ -1,24 +1,22 @@
 # References:
 # * Orbax: https://github.com/google/orbax/blob/11d2934ecfff77e86b5e07d0fef02b67eff4511b/orbax/checkpoint/pytree_checkpoint_handler.py#L312
 import asyncio
-import functools
 import logging
 import os
+from dataclasses import dataclass
 from functools import partial
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import equinox
 import jax
 import jax.experimental.array_serialization.serialization as array_ser
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
-import tensorstore
-from jax.sharding import Mesh
-from tensorstore import TensorStore
+from jax.sharding import Mesh, Sharding
 from jaxtyping import PyTree
 
 import haliax as hax
-import haliax.tree_util as htu
+from haliax.jax_utils import is_jax_array_like
 from haliax.partitioning import ResourceMapping
 from haliax.util import is_named_array
@@ -45,15 +43,23 @@ def tree_serialize_leaves_tensorstore(
     else:
         manager_was_none = False
 
-    leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=_is_named_or_none)
+    leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array)
+    assert len(jax.tree.leaves(leaf_key_paths, is_leaf=is_named_array)) == len(
+        jax.tree.leaves(pytree, is_leaf=is_named_array)
+    )
 
-    def path_from_key_path(key_path):
-        return os.path.join(checkpoint_dir, *key_path.split("."))
+    paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths)
 
-    paths = jtu.tree_map(path_from_key_path, leaf_key_paths, is_leaf=lambda x: x is None)
-    paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None)
-    leaves = jtu.tree_leaves(pytree, is_leaf=lambda x: x is None)
-    assert len(leaves) == len(paths)
+    # make a dataclass since tuples are pytrees
+    @dataclass
+    class Pair:
+        path: str
+        leaf: Any
+
+    zipped = jax.tree.map(lambda x, y: Pair(x, y), paths, pytree, is_leaf=lambda x: x is None)
+    paired_leaves = jax.tree.leaves(zipped)
+    paths = [p.path for p in paired_leaves]
+    leaves = [p.leaf.array if is_named_array(p.leaf) else p.leaf for p in paired_leaves]
 
     # ok, not all of these are arrays, but we'll deal with that in the async function
     def _ensure_is_array(x):
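
A note on the Pair dataclass above: tuples are themselves pytrees, so jax.tree.leaves would recurse into a (path, leaf) tuple rather than treating it as a single zipped leaf, while an unregistered dataclass is opaque to JAX and survives flattening intact. A self-contained sketch of the trick (illustrative names, standard JAX API only):

from dataclasses import dataclass
from typing import Any

import jax

@dataclass
class Pair:  # not registered as a pytree node, so JAX treats it as a leaf
    path: str
    leaf: Any

paths = {"w": "ckpt/w", "b": "ckpt/b"}
tree = {"w": 1.0, "b": 2.0}

zipped = jax.tree.map(lambda p, x: Pair(p, x), paths, tree)
assert all(isinstance(p, Pair) for p in jax.tree.leaves(zipped))

# With tuples instead, flattening descends into each pair and mixes
# paths and values together (dict keys flatten in sorted order):
zipped_tuples = jax.tree.map(lambda p, x: (p, x), paths, tree)
assert jax.tree.leaves(zipped_tuples) == ["ckpt/b", 2.0, "ckpt/w", 1.0]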
@@ -79,113 +85,93 @@ def _ensure_is_array(x):
     manager.wait_until_finished()
 
 
-def _tensorstore_spec_for(checkpoint_dir, key_path: str):
-    checkpoint_path = os.path.join(checkpoint_dir, *key_path.split("."))
-    ts_spec = array_ser.get_tensorstore_spec(checkpoint_path)
-    return ts_spec
-
-
-async def _serialize_one_leaf(x, spec):
-    if isinstance(x, hax.NamedArray):
-        # we don't need to do anything special for named arrays to serialize, though we will for deserialization.
-        return await _serialize_one_leaf(x.array, spec)
-    elif isinstance(x, jax.Array):
-        if not x.is_fully_addressable:
-            return await array_ser.async_serialize(x, spec)
-        else:
-            return await save_array_to_tensorstore(x, spec)
-    elif isinstance(x, (bool, float, complex, int)):
-        return await save_array_to_tensorstore(np.array(x), spec)
-    elif x is None:
-        return
-    elif isinstance(x, jnp.ndarray):
-        return await save_array_to_tensorstore(x, spec)
-    elif isinstance(x, np.ndarray):
-        return await save_array_to_tensorstore(x, spec)
-    else:
-        raise TypeError(f"Can't serialize {type(x)}")
-
-
-async def save_array_to_tensorstore(x, spec):
-    if jax.process_index() == 0:
-        if x.dtype == jnp.bfloat16:
-            # Tensorstore uses 'bfloat16', not '<V2'.
-            dtype = "bfloat16"
-        else:
-            dtype = np.dtype(x.dtype).str
-        t = await tensorstore.open(
-            tensorstore.Spec(spec), create=True, shape=x.shape, dtype=dtype, context=array_ser.TS_CONTEXT
-        )
-
-        await t.write(x)
-
-
-async def load_array_from_tensorstore(spec):
-    t: TensorStore = await tensorstore.open(tensorstore.Spec(spec), context=array_ser.TS_CONTEXT)
-    return await t.read()
-
-
-async def _deserialize_one_leaf(like, spec, axis_mapping, mesh):
-    if is_named_array(like):
-        return await _deserialize_named_array(like, spec, axis_mapping, mesh)
-    elif isinstance(like, jax.Array):
-        if not like.is_fully_addressable:
-            return await array_ser.async_deserialize(like.sharding, spec, global_shape=like.shape, dtype=like.dtype)
-        else:
-            return await load_array_from_tensorstore(spec)
-    elif isinstance(like, (bool, float, complex, int)):
-        arr = await load_array_from_tensorstore(spec)
-        return arr.item()
-    elif like is None:
-        return None
-    elif isinstance(like, jnp.ndarray) or isinstance(like, np.ndarray) or isinstance(like, jax.ShapeDtypeStruct):
-        return await load_array_from_tensorstore(spec)
-    elif callable(like):
-        return like
-    else:
-        raise TypeError(f"Can't deserialize {type(like)}")
-
-
-async def _deserialize_named_array(like, spec, axis_mapping, mesh):
-    # the main thing we're worried about is deserialized NamedArrays that are not yet arrays but are ShapedDtypeStructs.
-    # These don't (currently) have sharding info, but we can infer it from the axes
-    if isinstance(like.array, jax.ShapeDtypeStruct):
-        sharding = hax.partitioning.sharding_for_axis(like.axes, axis_mapping, mesh)
-        array = await array_ser.async_deserialize(sharding, spec, global_shape=like.array.shape, dtype=like.dtype)
-        assert sharding.is_equivalent_to(array.sharding, len(like.array.shape))
-        return hax.NamedArray(array, like.axes)
-    else:
-        array = await _deserialize_one_leaf(like.array, spec, axis_mapping, mesh)
-        return hax.NamedArray(array, like.axes)
+def _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths):
+    def path_from_key_path(key_path):
+        return os.path.join(checkpoint_dir, *key_path.split("."))
+
+    paths = jtu.tree_map(path_from_key_path, leaf_key_paths)
+    return paths
+
+
+def _sharding_from_leaf(leaf, axis_mapping, mesh) -> Optional[jax.sharding.Sharding]:
+    if is_named_array(leaf):
+        if leaf.array is None:
+            return None
+        return hax.partitioning.sharding_for_axis(leaf.axes, axis_mapping, mesh)
+    elif hasattr(leaf, "sharding") and getattr(leaf, "sharding") is not None:
+        return leaf.sharding
+    elif is_jax_array_like(leaf):
+        return _fully_replicated_sharding(mesh)
+    elif isinstance(leaf, (bool, float, complex, int, np.ndarray)):
+        return _fully_replicated_sharding(mesh)
+    else:
+        logger.warning(f"Unknown leaf type {type(leaf)}")
+        return None
+
+
+def _fully_replicated_sharding(mesh):
+    return hax.partitioning.sharding_for_axis((), {}, mesh)
 
 
 def tree_deserialize_leaves_tensorstore(
-    checkpoint_dir, pytree, axis_mapping: Optional[ResourceMapping] = None, mesh: Optional[Mesh] = None
+    checkpoint_dir,
+    pytree,
+    axis_mapping: Optional[ResourceMapping] = None,
+    mesh: Optional[Mesh] = None,
+    manager: Optional[array_ser.GlobalAsyncCheckpointManager] = None,
 ):
     """
     Deserializes a PyTree of Arrays and NamedArrays from a Tensorstore checkpoint, returning a pytree with the same shape
     as the one provided. This method is capable of deserializing NamedArrays that are the result of an eval_shape call
     (i.e. they are not yet arrays but are ShapedDtypeStructs), provided you pass in the axis_mapping and mesh (or
     they are available by context)
 
-    :param checkpoint_dir: the directory containing the tensorstore checkpoint, can be a local path or a GCS path
-    :param pytree: the exemplar pytree
-    :param axis_mapping: optional, the axis mapping for the NamedArrays (if they are not yet arrays)
-    :param mesh: optional, the mesh for the NamedArrays (if they are not yet arrays)
+    Args:
+        checkpoint_dir: the directory containing the tensorstore checkpoint, can be a local path or a GCS path
+        pytree: the exemplar pytree
+        axis_mapping: optional, the axis mapping for the NamedArrays (if they are not yet arrays)
+        mesh: optional, the mesh for the NamedArrays (if they are not yet arrays)
+        manager: optional, the checkpoint manager to use. If not provided, a new one will be created
 
-    :return: a pytree with the same shape as the exemplar pytree, but with the arrays deserialized from the checkpoint
+    Returns:
+        A pytree with the same shape as the exemplar pytree, but with the arrays deserialized from the checkpoint
     """
+    if manager is None:
+        manager = array_ser.GlobalAsyncCheckpointManager()
+
+    shardings: PyTree[Optional[Sharding]] = jtu.tree_map(
+        partial(_sharding_from_leaf, axis_mapping=axis_mapping, mesh=mesh), pytree, is_leaf=is_named_array
+    )
+
     # TODO: support ShapeDtypeStructs that are not NamedArrays
-    leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array)
-    specs = htu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths)
+    leaf_key_paths = jax_utils.leaf_key_paths(shardings, is_leaf=is_named_array)
+    paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths)
+    paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None)
 
-    deser_partial = functools.partial(_deserialize_one_leaf, axis_mapping=axis_mapping, mesh=mesh)
+    shardings_leaves, shardings_structure = jtu.tree_flatten(shardings, is_leaf=_is_named_or_none)
+    assert len(shardings_leaves) == len(paths)
 
-    futures = jtu.tree_map(deser_partial, pytree, specs, is_leaf=is_named_array)
-    leaves, structure = jtu.tree_flatten(futures, is_leaf=is_named_array)
+    # ok, so, jax really doesn't want any Nones in the leaves here, so we need to temporarily partition the pytree
+    real_indices = [i for i, x in enumerate(shardings_leaves) if x is not None]
+    real_leaves = [x for x in shardings_leaves if x is not None]
+    real_paths = [paths[i] for i in real_indices]
 
-    async def _do_deserialize():
-        values = await asyncio.gather(*leaves)
-        return jtu.tree_unflatten(structure, values)
+    deser_leaves = manager.deserialize_with_paths(shardings=real_leaves, paths=real_paths)
 
-    return asyncio.run(_do_deserialize())
+    # now we need to recreate the original structure
+    out_leaves = [None] * len(shardings_leaves)
+    for i, x in zip(real_indices, deser_leaves):
+        out_leaves[i] = x
+
+    deser_arrays = jtu.tree_unflatten(shardings_structure, out_leaves)
+
+    # deser_arrays only has arrays, but we need named arrays for at least some.
+    # The original pytree has the structure we want, so we'll use that to rebuild the named arrays
+    def _rebuild_named_array(like, array):
+        if is_named_array(like):
+            return hax.NamedArray(array, like.axes)
+        else:
+            return array
+
+    return jtu.tree_map(_rebuild_named_array, pytree, deser_arrays, is_leaf=_is_named_or_none)
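
With manager threaded through both entry points, a round trip looks roughly like this. A hedged sketch: the directory, axis name, and model are illustrative, the serialize-side argument order is assumed from the diff, and a device mesh is assumed to be available for sharding:

import jax
import haliax as hax

from levanter.tensorstore_serialization import (
    tree_deserialize_leaves_tensorstore,
    tree_serialize_leaves_tensorstore,
)

Hidden = hax.Axis("hidden", 8)
model = {"proj": hax.random.normal(jax.random.PRNGKey(0), (Hidden,))}

with jax.sharding.Mesh(jax.devices(), ("devices",)):
    # write the pytree; with no manager passed, the call blocks via wait_until_finished()
    tree_serialize_leaves_tensorstore("/tmp/ckpt-sketch", model)

    # an eval_shape template holds ShapeDtypeStructs inside its NamedArrays;
    # the deserializer infers shardings for them from the axes
    template = jax.eval_shape(lambda: model)
    restored = tree_deserialize_leaves_tensorstore("/tmp/ckpt-sketch", template)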

src/levanter/utils/jax_utils.py (32 changes: 19 additions & 13 deletions)

@@ -152,44 +152,50 @@ def leaf_key_paths(
             x, prefix=join_key(prefix, p), is_leaf=is_leaf, use_state_dict_keys=use_state_dict_keys
         )
 
+    out: PyTree[str]
+
     if is_leaf is not None and is_leaf(pytree):
-        return prefix
+        out = prefix
+    elif pytree is None:
+        out = None
     elif isinstance(pytree, dict):
-        return {k: rec(v, k) for k, v in pytree.items()}
+        out = {k: rec(v, k) for k, v in pytree.items()}
     elif _isnamedtupleinstance(pytree):
         d = {k: rec(v, k) for k, v in pytree._asdict().items()}
-        return pytree.__class__(**d)
+        out = pytree.__class__(**d)
     elif isinstance(pytree, list):
-        return [rec(v, str(i)) for i, v in enumerate(pytree)]
+        out = [rec(v, str(i)) for i, v in enumerate(pytree)]
     elif isinstance(pytree, tuple):
-        return tuple(rec(v, str(i)) for i, v in enumerate(pytree))
+        out = tuple(rec(v, str(i)) for i, v in enumerate(pytree))
     elif isinstance(pytree, eqx.Module):
         names = []
         rec_values = []
         for field in fields(pytree):
             if field.metadata.get("static", False):
                 continue
             field_name = field.name
-            field = getattr(pytree, field_name)
+            field_value = getattr(pytree, field_name)
             names.append(field_name)
 
             if use_state_dict_keys and hasattr(pytree, "_state_dict_key_map"):
                 field_name = pytree._state_dict_key_map().get(field_name, field_name)
 
-            rec_value = rec(field, field_name)
+            rec_value = rec(field_value, field_name)
             rec_values.append(rec_value)
 
         _, tree_def = eqx.tree_flatten_one_level(pytree)
         out = jax.tree_util.tree_unflatten(tree_def, rec_values)
-        return out
         # this doesn't work reliably because tree_at doesn't like none values
         # return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values, is_leaf=lambda x: x is None)
     else:
         leaves, treedef = jax.tree_util.tree_flatten(pytree, is_leaf=is_leaf)
-        if len(leaves) == 1:
-            return jax.tree_util.tree_unflatten(treedef, [f"{prefix}"])
+        if len(leaves) == 0:
+            out = None
+        elif len(leaves) == 1:
+            out = jax.tree_util.tree_unflatten(treedef, [f"{prefix}"])
         else:
-            return jax.tree_util.tree_unflatten(treedef, [join_key(prefix, str(i)) for i in range(len(leaves))])
+            out = jax.tree_util.tree_unflatten(treedef, [join_key(prefix, str(i)) for i in range(len(leaves))])
+
+    # assert len(jax.tree.leaves(out, is_leaf=is_leaf)) == len(jax.tree.leaves(pytree, is_leaf=is_leaf)), (out, pytree)
+    return out
 
 
 def join_key(prefix, k):
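
For reference, leaf_key_paths assigns each leaf position a dot-joined string key, which the tensorstore code above maps to filesystem paths via key_path.split("."). A small sketch of the expected output for a plain dict/list pytree (assuming levanter.utils.jax_utils is importable):

from levanter.utils.jax_utils import leaf_key_paths

tree = {"encoder": {"w": 1.0, "b": 2.0}, "head": [3.0, 4.0]}
print(leaf_key_paths(tree))
# {'encoder': {'w': 'encoder.w', 'b': 'encoder.b'}, 'head': ['head.0', 'head.1']}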

tests/test_checkpoint.py (10 changes: 5 additions & 5 deletions)

@@ -13,7 +13,7 @@
 from jax import ShapeDtypeStruct
 from jax import numpy as jnp
 
-import haliax
+import haliax as hax
 from haliax import Axis
 
 from levanter.checkpoint import (
@@ -252,17 +252,17 @@ def test_load_from_checkpoint_or_initialize():
     Out = Axis("out", 1)
 
     def init_fn(key):
-        return haliax.nn.MLP.init(In, Out, 2, 1, key=key, use_bias=False, use_final_bias=False)
+        return hax.nn.MLP.init(In, Out, 2, 1, key=key, use_bias=False, use_final_bias=False)
 
     k0 = jax.random.PRNGKey(0)
     k1 = jax.random.PRNGKey(1)
 
     model0 = eqx.filter_jit(init_fn)(k0)
     model1 = eqx.filter_jit(init_fn)(k1)
 
-    is_checkpointed = jtu.tree_map(lambda _: False, model0)
+    is_checkpointed = hax.tree_util.tree_map(lambda _: False, model0)
     is_checkpointed = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True)
-    is_checkpointed1 = jtu.tree_map(lambda _: False, model1)
+    is_checkpointed1 = hax.tree_util.tree_map(lambda _: False, model1)
     is_checkpointed1 = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed1, replace=True)
 
     with jax.sharding.Mesh(jax.devices(), ("devices",)), tempfile.TemporaryDirectory() as tmpdir:
@@ -306,7 +306,7 @@ def test_load_from_checkpoint_or_initialize_works_if_file_not_found():
     Out = Axis("out", 1)
 
     def init_fn(key):
-        return haliax.nn.MLP.init(In, Out, 2, 3, key=key)
+        return hax.nn.MLP.init(In, Out, 2, 3, key=key)
 
     k0 = jax.random.PRNGKey(0)
     k1 = jax.random.PRNGKey(1)

tests/test_export_to_hf.py (8 changes: 1 addition & 7 deletions)

@@ -41,13 +41,7 @@ def test_export_lm_to_hf():
     config = export_lm_to_hf.ConvertLmConfig(
         checkpoint_path=f"{tmpdir}/ckpt",
         output_dir=f"{tmpdir}/output",
-        model=export_lm_to_hf.Gpt2Config(
-            num_layers=2,
-            num_heads=2,
-            seq_len=64,
-            use_flash_attention=True,
-            hidden_dim=32,
-        ),
+        model=model_config,
     )
     export_lm_to_hf.main(config)