diff --git a/docs/state-dict.md b/docs/state-dict.md
index 7cdc54e..6bbc7d8 100644
--- a/docs/state-dict.md
+++ b/docs/state-dict.md
@@ -1,24 +1,125 @@
-# Serialization
+# State Dicts and Serialization
 
-Haliax supports serialization of modules (including any [equinox.Module][]) to and from PyTorch-compatible
-state dicts using the [safetensors](https://github.com/huggingface/safetensors) library. For details on
-how state dicts work in PyTorch, see the [PyTorch documentation](https://pytorch.org/docs/stable/notes/serialization.html#saving-and-loading-torch-nn-modules).
+Haliax follows the Equinox convention of mostly working in terms of module trees, similar to PyTorch. However,
+sometimes we would rather have all of the weights in a single, easy-to-manipulate dictionary.
+For example, when saving a model to disk, you typically want to save the weights and biases in a single file (especially
+for compatibility with PyTorch and other ecosystems).
+Similarly, someone coming from Flax might be more comfortable working with a parameter dictionary rather
+than wrangling PyTrees directly.
+
+This is where state dicts come in.
 A state dict is a Python dictionary that maps string keys to tensors. It is used to store the parameters
 of a model (though typically not the model's structure or hyperparameters). The keys are typically the names of the
 model's parameters, arranged as `.`-separated paths. For example, a model with a `conv1` layer might have a state dict
 with keys like `conv1.weight` and `conv1.bias`. Sequences of modules (e.g., for lists of layers) are
-serialize with keys like `layer.0.weight`, `layer.1.weight`, etc.
+serialized with keys like `layer.0.weight`, `layer.1.weight`, etc.
+
+The values in a Haliax state dict are typically JAX or NumPy arrays, not NamedArrays.
+
+## Basic State Dicts
+
+To create a state dict from a module, use the [haliax.state_dict.to_state_dict][] function. This function takes a module
+and returns a state dict:
+
+```python
+import haliax as hax
+import jax.random as jrandom
+
+# Create a module
+Heads = hax.Axis("Heads", 8)
+Dim = hax.Axis("Dim", 16)
+Out = hax.Axis("Out", 5)
+
+module = hax.nn.Linear.init(In=(Heads, Dim), Out=Out, key=jrandom.PRNGKey(0))
+
+# Serialize the module to a state dict
+state_dict = hax.state_dict.to_state_dict(module)
+```
+
+You can manipulate the state dict as you would any other Python dictionary. Note that the arrays are JAX arrays, not
+NamedArrays or NumPy arrays. In particular, with `to_state_dict`, the arrays still preserve any sharding or vmapping.
+This makes it a good choice for use inside JIT.
+
+If you want a CPU-only state dict, you can use the [haliax.state_dict.to_numpy_state_dict][] function:
+
+```python
+# Serialize the module to a state dict of NumPy arrays
+
+state_dict = hax.state_dict.to_numpy_state_dict(module)
+```
+
+To load a state dict back into a module, use the [haliax.state_dict.from_state_dict][] function. This function
+requires a "template" module that has the same structure as the module that was serialized to the state dict:
+
+```python
+# Load the state dict into a module
+module = hax.state_dict.from_state_dict(module, state_dict)
+```
+
+One trick is that you can use the `init` method of a module inside [equinox.filter_eval_shape][] to create
+an abstract version of the module that can be used as a template for loading the state dict. This is useful if you
+want to avoid allocating a bunch of large arrays just to load the state dict.
+
+```python
+import equinox as eqx
+
+module_template = eqx.filter_eval_shape(hax.nn.Linear.init, In=(Heads, Dim), Out=Out)
+module = hax.state_dict.from_state_dict(module_template, state_dict)
+```
+
+### Saving a State Dict
+
+!!! warning
+    The default Haliax state dicts are not in general compatible with PyTorch. If you want to load a Haliax state dict
+    into PyTorch, you will need to convert it to a PyTorch-compatible state dict first and then use the
+    [safetensors](https://github.com/huggingface/safetensors) library to load it.
+    See the [PyTorch-Compatible State Dicts](#pytorch-compatible-state-dicts) section for how to do this.
+
+To save the state dict to a file, use the [haliax.state_dict.save_state_dict][] function together with the
+[haliax.state_dict.to_numpy_state_dict][] function:
+
+```python
+# Save the state dict to a file
+hax.state_dict.save_state_dict(state_dict, 'state_dict.safetensors')
+```
+
+Likewise, you can load a state dict from a file using the [haliax.state_dict.load_state_dict][] function:
+
+```python
+# Load the state dict from a file
+state_dict = hax.state_dict.load_state_dict('state_dict.safetensors')
+```
+
+#### Things to know
+
+Haliax supports serialization of modules (including any [equinox.Module][]) to and from PyTorch-compatible
+state dicts using the [safetensors](https://github.com/huggingface/safetensors) library. For details on
+how state dicts work in PyTorch, see the [PyTorch documentation](https://pytorch.org/docs/stable/notes/serialization.html#saving-and-loading-torch-nn-modules). (Levanter has JAX-native TensorStore-based
+serialization that we should upstream here.)
 
 Haliax uses the [safetensors](https://github.com/huggingface/safetensors) library to serialize state dicts. This
-library is a safer, more portable format developed by Hugging Face. Serializing a native PyTorch state dict requires
+library, developed by Hugging Face, provides a safer and more portable format. (Serializing a native PyTorch state dict requires
 PyTorch itself, and we want to avoid that dependency. Also, PyTorch uses pickles, which are in general not
-safe to deserialize from untrusted sources.
+safe to deserialize from untrusted sources.)
 
 This does mean that you can't directly load a Haliax state dict into PyTorch, but safetensors is lightweight and
 easy to use. Hugging Face natively supports it in their libraries.
+(See the [PyTorch-Compatible State Dicts](#pytorch-compatible-state-dicts) section for more details
+on how to convert a Haliax state dict to a PyTorch-compatible state dict.)
+
 
-## Saving a State Dict
+## PyTorch-Compatible State Dicts
+
+Haliax provides a way to serialize a module to a PyTorch-compatible state dict. This is useful if you want to
+load the weights of a Haliax module into a PyTorch model or vice versa. The [haliax.state_dict.to_torch_compatible_state_dict][]
+and [haliax.state_dict.from_torch_compatible_state_dict][] functions allow you to convert a Haliax module to and from
+a PyTorch-compatible state dict.
+
+Note that these methods behave a bit differently from the basic state dict methods. See
+the [Explanation](#explanation) section for more details.
+
+### Saving a State Dict
 
 To serialize a module to a Pytorch-compatible state dict, use the [haliax.state_dict.to_torch_compatible_state_dict][]
 function. This function takes a module and returns a state dict.
To save the state dict to a file, use the
@@ -57,13 +158,14 @@ model = torch.nn.Linear(10, 5)
 
 state_dict = load_model(model, 'state_dict.safetensors')
 ```
 
-## Loading a State Dict
+### Loading a State Dict
 
 Similarly, you can load a state dict from a file using the [haliax.state_dict.load_state_dict][] function. This
 function reads a state dict from a file in safetensors format and returns a dictionary. To load the state dict
 into a module, use the [haliax.state_dict.from_torch_compatible_state_dict][] function.
 
 ```python
+import haliax.state_dict
 import haliax as hax
 import jax.random as jrandom
@@ -77,13 +179,79 @@ module = hax.nn.Linear.init(In=(Heads, Dim), Out=Out, key=jrandom.PRNGKey(0))
 
 state_dict = hax.state_dict.load_state_dict('state_dict.safetensors')
 
 # this will unflatten the state dict and load it into the module
-module = hax.state_dict.from_torch_compatible_state_dict(module, state_dict)
+module = haliax.state_dict.from_torch_compatible_state_dict(module, state_dict)
 ```
 
-The `from_torch_compatible_state_dict` function will unflatten the state dict and load it into the module. Note
+The `from_torch_compatible_state_dict` function will unflatten linear layers and restack any stacked modules as it loads the state dict into the module. Note
 that the module must have the same structure as the module that was serialized to the state dict. If the module
 structure has changed, you may need to manually update the state dict keys to match the new structure.
 
+### Explanation
+
+By default, Haliax creates a state dict using key paths and arrays that mirror the module's structure. For example, a
+Linear module with an `In` axis spec of `(Head, HeadDim)` and an `Out` axis spec of `Embed` will have a state dict
+with a key `weight` that is a three-dimensional tensor with shape `(Embed, Head, HeadDim)`. Moreover,
+instances of [haliax.nn.Stacked][] (i.e. our "scan-layers" module) will have a state dict in which the vmapped
+layers appear as single stacked arrays:
+
+```python
+import haliax as hax
+import haliax.nn as hnn
+import jax.random as jrandom
+
+Heads = hax.Axis("Heads", 8)
+Dim = hax.Axis("Dim", 16)
+Out = hax.Axis("Out", 5)
+Block = hax.Axis("Block", 3)
+
+keys = jrandom.split(jrandom.PRNGKey(0), Block.size)
+
+stacked_module = hnn.Stacked.init(Block, hnn.Linear)(In=(Heads, Dim), Out=Out, key=keys)
+
+state_dict = hax.state_dict.to_state_dict(stacked_module)
+
+for k, v in state_dict.items():
+    print(k, v.shape)
+
+# Output:
+# weight (3, 5, 8, 16)
+# bias (3, 5)
+```
+
+PyTorch expects the weights of a linear layer to be a 2D tensor with shape `(out_features, in_features)` and the bias
+to be a 1D tensor with shape `(out_features,)`. Moreover, it expects the layers of a stacked module to be unstacked
+into a `torch.nn.Sequential`. To do this, we use the [haliax.state_dict.to_torch_compatible_state_dict][] function:
+
+```python
+torch_state_dict = hax.state_dict.to_torch_compatible_state_dict(stacked_module)
+
+for k, v in torch_state_dict.items():
+    print(k, v.shape)
+
+# Output:
+# 0.weight (5, 128)
+# 0.bias (5,)
+# 1.weight (5, 128)
+# 1.bias (5,)
+# 2.weight (5, 128)
+# 2.bias (5,)
+
+# save it
+hax.state_dict.save_state_dict(torch_state_dict, 'torch_state_dict.safetensors')
+
+# load it
+import torch
+import safetensors.torch as st
+
+model = torch.nn.Sequential(
+    torch.nn.Linear(128, 5),
+    torch.nn.Linear(128, 5),
+    torch.nn.Linear(128, 5)
+)
+
+st.load_model(model, 'torch_state_dict.safetensors')  # loads the weights into `model` in place
+```
 
 ## Customizing Serialization
 
@@ -151,8 +319,6 @@ to join the prefix to the keys of the state dict.
```python def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> T: ... - - ``` ## API Reference diff --git a/pyproject.toml b/pyproject.toml index 1d40bf8..bcbb8b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "equinox>=0.10.6", "jaxtyping>=0.2.20", "jmp>=0.0.4", - "safetensors>=0.4.3" + "safetensors>=0.4.3", + "optax>=0.2.4", ] dynamic =[ "version" ] diff --git a/src/haliax/_src/rearrange.py b/src/haliax/_src/rearrange.py index ce6e5ef..08e345e 100644 --- a/src/haliax/_src/rearrange.py +++ b/src/haliax/_src/rearrange.py @@ -114,7 +114,8 @@ def axis_spec_rearrange(array: NamedArray, axis_spec: PartialAxisSpec) -> NamedA permute_spec.append(index_of_in[ax]) out_axes = tuple(array.axes[i] for i in typing.cast(list[int], permute_spec)) - return NamedArray(jnp.transpose(array.array, permute_spec), out_axes) + transpose = jnp.transpose(array.array, permute_spec) if array.array is not None else None + return NamedArray(transpose, out_axes) def einops_rearrange(array: NamedArray, expression: str, **bindings: AxisSelector | int) -> NamedArray: diff --git a/src/haliax/_src/state_dict.py b/src/haliax/_src/state_dict.py index e8b7a7d..0b181f9 100644 --- a/src/haliax/_src/state_dict.py +++ b/src/haliax/_src/state_dict.py @@ -1,23 +1,26 @@ -# Module to support torch-style "state dict" serialization via safetensors +# Module to support the creation of "state dict", including supporting Torch-compatible serialization via safetensors import dataclasses import typing -from typing import Any, Optional, Sequence, TypeVar +from typing import Any, Optional, Sequence, TypeVar, cast +import equinox import equinox as eqx import jax -import jax.numpy as jnp import numpy as np from jax import ShapeDtypeStruct +from jax import numpy as jnp from jax.experimental.multihost_utils import sync_global_devices from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.tree_util import DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey from jaxtyping import PyTree +import haliax import haliax.partitioning as partitioning from haliax._src.util import index_where from haliax.axis import Axis from haliax.core import NamedArray, flatten_axes, named from haliax.jax_utils import is_jax_array_like, is_scalarish +from haliax.types import FilterSpec try: @@ -31,40 +34,6 @@ T = TypeVar("T") -def from_torch_compatible_state_dict( - t: T, state_dict: StateDict, *, unflatten_linear: bool = True, prefix: Optional[str] = None -) -> T: - """ - Convert a state dict to a tree that is compatible with the structure of `t`. - - This applies [haliax.state_dict.from_state_dict][] followed by [haliax.state_dict.unflatten_linear_layers][]. - """ - if unflatten_linear: - t = _flatten_to_unflatten(t, state_dict, prefix) - else: - t = from_state_dict(t, state_dict, prefix=prefix) - - return t - - -def _flatten_to_unflatten(t, state_dict, prefix): - """ - Flatten the torch compatible state_dict before loading into t, and then recover the unflattened layers. - """ - # typically, `t` is a bunch of ShapeDtypeStructs, which can't be transposed etc. 
so we instead have to zeros() - # into real arrays (that aren't actually real b/c this is inside a jit) - def _dt_struct_to_array(struct): - if not isinstance(struct, ShapeDtypeStruct): - return struct - return jnp.zeros(struct.shape, struct.dtype) - - t = jax.tree.map(_dt_struct_to_array, t) - flat_t = flatten_linear_layers(t) - flat_t = from_state_dict(flat_t, state_dict, prefix=prefix) - t = unflatten_linear_layers(t, flat_t) - return t - - @typing.overload def with_prefix(prefix: str | None, leaf: str) -> str: ... @@ -81,7 +50,7 @@ def with_prefix(prefix: Optional[str], leaf: Optional[str]) -> Optional[str]: def with_prefix(prefix: Optional[str], leaf: Optional[str]) -> Optional[str]: - """Joins two optional path strings in a way compatible with pytorch state dict serialization""" + """Joins two optional path strings in a way compatible with PyTorch state dict serialization""" if prefix is None: return leaf elif leaf is None: @@ -91,7 +60,7 @@ def with_prefix(prefix: Optional[str], leaf: Optional[str]) -> Optional[str]: class ModuleWithStateDictSerialization(eqx.Module): - """An eqx.Module that can be serialized to a torch-style state dict.""" + """An eqx.Module that can be serialized to a state dict.""" def to_state_dict(self, prefix: Optional[str] = None) -> StateDict: return default_eqx_module_to_state_dict(self, prefix) @@ -130,7 +99,7 @@ def from_state_dict(tree: T, state_dict: StateDict, prefix: Optional[str] = None return {k: from_state_dict(v, state_dict, prefix=with_prefix(prefix, k)) for k, v in tree.items()} # type: ignore elif isinstance(tree, NamedArray): if prefix is None: - raise ValueError("Cannot extract a leaf value from a torch dict without a prefix") + raise ValueError("Cannot extract a leaf value from a state dict without a prefix") array = state_dict[prefix] @@ -160,15 +129,25 @@ def from_state_dict(tree: T, state_dict: StateDict, prefix: Optional[str] = None return state_dict.get(prefix, tree) -def to_state_dict(tree: PyTree, prefix: Optional[str] = None) -> StateDict: +def to_state_dict(tree: PyTree, prefix: Optional[str] = None, *, is_leaf: typing.Callable | None = None) -> StateDict: """ Convert a PyTree to a state dict. + Args: + tree: The tree to convert. + prefix: The prefix to use for the keys in the state dict. + is_leaf: A function that determines whether a node in the tree is a leaf. In addition to this function, + NamedArrays and the built-in JAX types are considered leaves. + Returns: The state dict representation of the input tree. """ if tree is None: return {} + elif is_leaf is not None and is_leaf(tree): + if prefix is None: + raise ValueError("Cannot convert a leaf value to a state dict without a prefix") + return {prefix: tree} elif isinstance(tree, eqx.Module): if hasattr(tree, "to_state_dict"): state_dict = tree.to_state_dict(prefix) @@ -374,6 +353,95 @@ def load_state_dict(path): return state_dict +def to_torch_compatible_state_dict( + t: T, + *, + flatten_linear: bool = True, + unstack_stacked: bool = True, + prefix: Optional[str] = None, + filter: FilterSpec = is_jax_array_like, +) -> StateDict: + """ + Convert a tree to a state dict that is compatible with Torch-style state dicts. + + "Torch-style" here means two things: + + 1. [haliax.nn.Stacked][] instances are unstacked into a list of modules (instead of being represented as a single + module with a vmapped set of leaves.) This means they can be read into torch.nn.Sequential instances. + 2. Linear layers are represented in their flattened form, i.e. 
as a 2d matrix and 1d bias vector (instead of the + more structure form we support). This means they can be read into torch.nn.Linear instances. + + This applies [haliax.state_dict.flatten_linear_layers][] followed by [haliax.state_dict.to_state_dict][] + + Args: + t: The tree to convert + flatten_linear: Whether to flatten linear layers + unstack_stacked: Whether to unstack stacked modules + prefix: The prefix to use for the state dict keys + filter: The filter to use for selecting which nodes to include in the state dict. By default, this includes only + array-like objects (e.g. JAX and NumPy arrays). + """ + t = equinox.filter(t, filter) + if unstack_stacked: + t = _unstack_stacked(t) + if flatten_linear: + t = flatten_linear_layers(t) + return to_numpy_state_dict(t, prefix=prefix) + + +def from_torch_compatible_state_dict( + t: T, + state_dict: StateDict, + *, + unflatten_linear: bool = True, + restack_stacked: bool = True, + prefix: Optional[str] = None, +) -> T: + """ + Convert a state dict to a tree that is compatible with the structure of `t`. + + This function is the inverse of [haliax.state_dict.to_torch_compatible_state_dict][]. It has two (configurable) + behaviors on top of to_state_dict: + + 1) Linear layers are unflattened, i.e. converted from a 2d matrix and 1d bias vector to the more structured form + present in the input tree. + 2) [haliax.nn.Stacked][] instances are restacked, i.e. converted from a list of modules to a single module with + a vmapped set of leaves. + """ + # This function is a bit weird internally: it has to recreate the flattened/unstacked state dict + # so that it can be passed to from_state_dict. Then we undo the flattening/unstacking. + + if restack_stacked or unflatten_linear: + # typically, `t` is a bunch of ShapeDtypeStructs, which can't be transposed etc. so we instead have to zeros() + # into real arrays (that aren't actually real b/c this is inside a jit) + def _dt_struct_to_array(struct): + if not isinstance(struct, ShapeDtypeStruct): + return struct + return jnp.zeros(struct.shape, struct.dtype) + + t = jax.tree.map(_dt_struct_to_array, t) + + orig_t = t + + if restack_stacked: + t = _unstack_stacked(t) + + pre_flatten = t + + if unflatten_linear: + t = flatten_linear_layers(t) + + t = from_state_dict(t, state_dict, prefix=prefix) + + if unflatten_linear: + t = unflatten_linear_layers(pre_flatten, t) + + if restack_stacked: + t = _restack_stacked(orig_t, t) + + return t + + def flatten_linear_layers(tree: T) -> T: """ In PyTorch, linear layers are stored as a 2d weight matrix and a 1d bias vector. In Haliax, @@ -394,6 +462,10 @@ def _flatten_linear(layer): new_Out: Axis = flatten_axes(layer.Out, "__OUT__") new_In: Axis = flatten_axes(layer.In, "__IN__") + # TODO: ensure sharding? + # out_pspec = haliax.partitioning.pspec_for_axis(layer.Out) + # in_pspec = haliax.partitioning.pspec_for_axis(layer.In) + if weight.array is not None: out_first = layer.out_first weight = weight.flatten_axes(layer.Out, new_Out).flatten_axes(layer.In, new_In) @@ -447,3 +519,63 @@ def _unflatten_linear(template, flattened): return jax.tree.map( _unflatten_linear, template, tree_with_flattened_linears, is_leaf=lambda x: isinstance(x, Linear) ) + + +def _unstack_stacked(tree: PyTree) -> PyTree: + """ + Unstack all [haliax.nn.Stacked][] instances in a tree, returning a new tree with all Stacked instances + converted to BlockSeq instances. 
+ """ + from haliax.nn import Stacked + + def _unstack_layer(layer): + if not isinstance(layer, Stacked): + return layer + + bs_layer = layer.as_block_seq() + return _unstack_stacked(bs_layer) + + return jax.tree.map(_unstack_layer, tree, is_leaf=lambda x: isinstance(x, Stacked)) + + +def _restack_stacked(template: PyTree, tree: PyTree) -> PyTree: + """ + Restack all [haliax.nn.Stacked][] instances in a tree, returning a new tree with all BlockSeq instances + converted to Stacked instances. + + This implementation could be cleverer in the way it handles nested stacks. but those are rare enough that + it's not worth the complexity. + """ + from haliax.nn import BlockSeq, Stacked + + def _restack_layer(template, layer): + if not isinstance(template, Stacked) or not isinstance(layer, BlockSeq): + return layer + + Block = template.Block + assert template.Block == layer.Block, "Block mismatch" + + # first handle the recursion. + # the recursion is actually surprsingly stricky. each block in layer.blocks must be handled + # separately + template_0 = template.get_block(0) + layer_blocks = [_restack_stacked(template_0, block) for block in layer.blocks] + + def restack_tree_leaf(leaves): + if isinstance(leaves[0], NamedArray): + return haliax.stack(Block, leaves) + + if is_jax_array_like(leaves[0]): + return jnp.stack(leaves, axis=0) + + if not all(leaf == leaves[0] for leaf in leaves): + raise ValueError(f"Unsupported type for restacking {type(leaves[0])}") + + return leaves[0] + + restacked = haliax.tree_util.tree_map(lambda *leaves: restack_tree_leaf(leaves), *layer_blocks) + return Stacked( + restacked, Block, gradient_checkpointing=layer.gradient_checkpointing, prevent_cse=template.prevent_cse + ) + + return jax.tree.map(_restack_layer, template, tree, is_leaf=lambda x: isinstance(x, Stacked)) diff --git a/src/haliax/core.py b/src/haliax/core.py index e8acf19..0aca90d 100644 --- a/src/haliax/core.py +++ b/src/haliax/core.py @@ -1093,7 +1093,10 @@ def unbind(array: NamedArray, axis: AxisSelector) -> List[NamedArray]: # arrays = jnp.rollaxis(array.array, axis=axis_index, start=0) # instead we just loop over the axes pulling one out at a time axis_size = array.axes[axis_index].size - arrays = [jnp.take(array.array, i, axis=axis_index) for i in range(axis_size)] + if array.array is None: + arrays = [None] * axis_size + else: + arrays = [jnp.take(array.array, i, axis=axis_index) for i in range(axis_size)] return [haliax.auto_sharded(NamedArray(a, new_axes)) for a in arrays] @@ -1236,7 +1239,9 @@ def _full_flatten( new_axes.append(ax) array = array.rearrange(intermediate_axes) - raw_array = array.array.reshape([ax.size for ax in new_axes]) + raw_array = array.array + if raw_array is not None: + raw_array = raw_array.reshape([ax.size for ax in new_axes]) return NamedArray(raw_array, tuple(new_axes)) diff --git a/src/haliax/nn/scan.py b/src/haliax/nn/scan.py index 6dc04db..4a57b3d 100644 --- a/src/haliax/nn/scan.py +++ b/src/haliax/nn/scan.py @@ -1,15 +1,12 @@ import functools -import re -from typing import Any, Dict, Generic, Optional, Protocol, Sequence, Type, TypeVar, cast +from typing import Dict, Generic, Optional, Protocol, Sequence, Type, TypeVar import equinox as eqx import jax -from jax import numpy as jnp import haliax import haliax.util -from haliax.jax_utils import filter_checkpoint, is_jax_array_like -from haliax.util import is_jax_or_hax_array_like +from haliax.jax_utils import filter_checkpoint from .._src.state_dict import ModuleWithStateDictSerialization, StateDict, with_prefix from 
..axis import Axis @@ -36,6 +33,7 @@ class BlockFoldable(Protocol[M]): """ Block: Axis + gradient_checkpointing: bool @classmethod def init( @@ -56,6 +54,14 @@ def unstacked(self) -> Sequence[M]: """ ... + def as_block_seq(self) -> "BlockSeq[M]": + """ + Convert this module to a BlockSeq. This is useful if you have a Stacked module and you want to convert it to a + BlockSeq, e.g. for saving checkpoints or logging. + + """ + ... + class BlockSeq(ModuleWithStateDictSerialization, Generic[M]): """ @@ -156,10 +162,10 @@ def from_state_dict(self: M, state_dict: StateDict, prefix: Optional[str] = None out_blocks = [] for i, block in enumerate(self.blocks): my_prefix = with_prefix(prefix, str(i)) - block = block.from_state_dict(state_dict, my_prefix) + block = haliax.state_dict.from_state_dict(block, state_dict, my_prefix) out_blocks.append(block) - return eqx.tree_at(lambda m: m.blocks, self, out_blocks) + return eqx.tree_at(lambda m: m.blocks, self, type(self.blocks)(out_blocks)) def to_state_dict(self, prefix: Optional[str] = None) -> StateDict: """ @@ -168,12 +174,14 @@ def to_state_dict(self, prefix: Optional[str] = None) -> StateDict: state_dict: StateDict = {} for i, block in enumerate(self.blocks): my_prefix = with_prefix(prefix, str(i)) - # we can't assume to_state_dict is implemented, so we have to do it manually block_dict = haliax.state_dict.to_state_dict(block, my_prefix) state_dict.update(block_dict) return state_dict + def as_block_seq(self) -> "BlockSeq[M]": + return self + class Stacked(ModuleWithStateDictSerialization, Generic[M]): """ @@ -318,10 +326,44 @@ def fold(self, init, *args, **kwargs): def _do_block(carry, block, *extra_args, **extra_kwargs): return block(carry, *extra_args, **extra_kwargs) - # TODO: this is for logic that's in levanter. We should move that logic to haliax I guess? def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"stacked": None} + def get_block(self, i: int) -> M: + """ + Get the ith block from the stacked module. You don't typically use this when doing computation, + but it's useful for debugging or introspection, or unstacking. + """ + + def get_block_leaf(leaf): + if isinstance(leaf, haliax.core.NamedArray): + return leaf[self.Block, i] + elif haliax.jax_utils.is_jax_array_like(leaf): + return leaf[i] + else: + return leaf + + return haliax.tree_util.tree_map(get_block_leaf, self.stacked) + + def set_block(self, i: int, block: M) -> "Stacked[M]": + """ + Set the ith block in the stacked module, returning a new Stacked module. + + Returns: + this module with the ith block replaced with the new block + """ + + def set_block_leaf(leaf): + if isinstance(leaf, haliax.core.NamedArray): + return leaf.at[self.Block, i].set(block) + elif haliax.jax_utils.is_jax_array_like(leaf): + return leaf.at[i].set(block) + else: + return block + + new_stacked = haliax.tree_util.tree_map(set_block_leaf, self.stacked) + return eqx.tree_at(lambda m: m.stacked, self, new_stacked) + def unstacked(self) -> Sequence[M]: """ Returns the unstacked version of this module. This is useful for logging or saving checkpoints. @@ -349,77 +391,5 @@ def unbatch_leaf(x): unstacked_leaves = tuple(zip(*unstacked_leaves)) return tuple(map(lambda x: jax.tree_util.tree_unflatten(structure, x), unstacked_leaves)) - def to_state_dict(self, prefix: Optional[str] = None) -> StateDict: - # this method needs to "devectorize" the blocks, so that we have a list of blocks h.0.FOO, h.1.FOO, etc. 
- # first just do the normal thing with our own dict, which we'll post-process - state_dict: StateDict = super().to_state_dict(prefix) - - return _unstack_state_dict(state_dict, prefix) - - def from_state_dict(self: M, state_dict: StateDict, prefix: Optional[str] = None) -> M: - # this method needs to "vectorize" the blocks, so that we have a single block h.FOO - # first just do the normal thing with our own dict, which we'll post-process - stacked = _stack_state_dict(state_dict, prefix=prefix) - out = super().from_state_dict(stacked, prefix=prefix) # type: ignore - return out - - -def _stack_state_dict(state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - """ - Stack all keys matching prefix in a new state dict, returning a state dict that has all keys matching - prefix stacked, but otherwise the same. - - Stacked in this case means roughly "compatible with a torch.nn.Sequential", which means that the - keys are of the form ".0.", ".1.", etc. - - Mostly for use with [haliax.nn.Stacked][]. - """ - vectorized_dict: StateDict = {} - - tensors_to_vectorize: dict[str, list[Optional[Any]]] = {} - if prefix is not None: - prefix_for_pat = re.escape(prefix + ".") - else: - prefix_for_pat = "" - pattern = re.compile(rf"{prefix_for_pat}(\d+)\.(.*)") - - for k, v in state_dict.items(): - match = pattern.match(k) - if match: - block_idx = int(match.group(1)) - block_key = match.group(2) - tensors = tensors_to_vectorize.setdefault(block_key, []) - if len(tensors) <= block_idx: - tensors.extend([None] * (block_idx - len(tensors) + 1)) - assert tensors[block_idx] is None, f"Duplicate key {k}" - tensors[block_idx] = v - else: - vectorized_dict[k] = v - - # now we have to vectorize the tensors - for k, tensors in tensors_to_vectorize.items(): - vectorized_dict[cast(str, with_prefix(prefix, k))] = jnp.stack(tensors, axis=0) - - return vectorized_dict - - -def _unstack_state_dict(state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - """ - Unstack all keys matching prefix in a new state dict, returning a state dict that has all keys matching - prefix unstacked, but otherwise the same. Mostly for use with [haliax.nn.Stacked][]. - - Unstacked in this case means roughly "compatible with a torch.nn.Sequential", which means that the - keys are of the form ".0.", ".1.", etc. 
- """ - new_dict: StateDict = {} - prefix = with_prefix(prefix, "") - assert prefix is not None - - for k, v in state_dict.items(): - if k.startswith(prefix) and is_jax_or_hax_array_like(v): - for i, v_i in enumerate(v): - new_dict[f"{prefix}{i}.{k[len(prefix):]}"] = v_i - else: - new_dict[k] = v - - return new_dict + def as_block_seq(self) -> BlockSeq[M]: + return BlockSeq(self.unstacked(), self.Block, self.gradient_checkpointing) diff --git a/src/haliax/quantization.py b/src/haliax/quantization.py index 9a3edcc..bcad64d 100644 --- a/src/haliax/quantization.py +++ b/src/haliax/quantization.py @@ -253,7 +253,7 @@ def _matches_target_fp8(key_path, config: Fp8Config) -> bool: return re.match(config.targets, key_path_str) is not None -def _key_path_to_str(key_path: tuple[BuiltInKeyEntry, ...]) -> str: +def _key_path_to_str(key_path) -> str: out = "" for k in key_path: match k: diff --git a/src/haliax/state_dict.py b/src/haliax/state_dict.py index fdf7a70..807400f 100644 --- a/src/haliax/state_dict.py +++ b/src/haliax/state_dict.py @@ -1,10 +1,3 @@ -from typing import Optional, TypeVar - -import equinox - -from haliax.jax_utils import is_jax_array_like -from haliax.types import FilterSpec - from ._src.state_dict import ( ModuleWithStateDictSerialization, StateDict, @@ -15,38 +8,14 @@ save_state_dict, to_numpy_state_dict, to_state_dict, + to_torch_compatible_state_dict, unflatten_linear_layers, with_prefix, ) -T = TypeVar("T") - - -def to_torch_compatible_state_dict( - t: T, *, flatten_linear: bool = True, prefix: Optional[str] = None, filter: FilterSpec = is_jax_array_like -) -> StateDict: - """ - Convert a tree to a state dict that is compatible with torch-style state dicts. - - This applies [haliax.state_dict.flatten_linear_layers][] followed by [haliax.state_dict.to_state_dict][] - - Args: - t: The tree to convert - flatten_linear: Whether to flatten linear layers - prefix: The prefix to use for the state dict keys - filter: The filter to use for selecting which nodes to include in the state dict. By default, this includes only - array-like objects (e.g. JAX and NumPy arrays). 
- """ - t = equinox.filter(t, filter) - if flatten_linear: - t = flatten_linear_layers(t) - return to_numpy_state_dict(t, prefix=prefix) - - __all__ = [ "ModuleWithStateDictSerialization", - "from_torch_compatible_state_dict", "load_state_dict", "save_state_dict", "from_state_dict", @@ -57,4 +26,5 @@ def to_torch_compatible_state_dict( "to_numpy_state_dict", "StateDict", "to_torch_compatible_state_dict", + "from_torch_compatible_state_dict", ] diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py index fe08962..ef14f86 100644 --- a/tests/test_state_dict.py +++ b/tests/test_state_dict.py @@ -4,10 +4,11 @@ import jax import jax.numpy as jnp import pytest +from optax.tree_utils import tree_zeros_like import haliax as hax -from haliax.nn import Linear -from haliax.nn.scan import _stack_state_dict, _unstack_state_dict +from haliax._src.state_dict import _restack_stacked, _unstack_stacked +from haliax.nn import BlockSeq, Linear, Stacked from haliax.state_dict import flatten_linear_layers, from_state_dict, to_state_dict, unflatten_linear_layers @@ -47,65 +48,6 @@ def test_flatten_linear_layers(out_dims_first: bool): assert linear == new_linear -# Test cases for stack_state_dict -@pytest.mark.parametrize( - "input_dict, prefix, expected_output", - [ - # Single block stacking - ( - { - "block.0.weight": jnp.array([1, 2]), - "block.0.bias": jnp.array([3]), - "block.1.weight": jnp.array([4, 5]), - "block.1.bias": jnp.array([6]), - }, - "block", - { - "block.weight": jnp.array([[1, 2], [4, 5]]), - "block.bias": jnp.array([[3], [6]]), - }, - ), - # Mixed data types and unmatched items remain unchanged - ( - { - "block.0.weight": jnp.array([1, 2]), - "block.0.bias": jnp.array([3]), - "block.1.weight": jnp.array([4, 5]), - "block.1.bias": jnp.array([6.0]), - "unrelated.item": jnp.array([7]), - }, - "block", - { - "block.weight": jnp.array([[1, 2], [4, 5]]), - "block.bias": jnp.array([[3.0], [6.0]]), - "unrelated.item": jnp.array([7]), - }, - ), - # No items match prefix, all items should remain unchanged - ( - { - "module.0.param": jnp.array([1]), - "module.1.param": jnp.array([2]), - }, - "block", - { - "module.0.param": jnp.array([1]), - "module.1.param": jnp.array([2]), - }, - ), - ], -) -def test_stack_state_dict(input_dict, prefix, expected_output): - result = _stack_state_dict(input_dict, prefix) - for key in expected_output: - assert jnp.all(jnp.array_equal(result[key], expected_output[key])), f"Failed on key: {key}" - - # now unstack it - unstacked = _unstack_state_dict(result, prefix) - for key in input_dict: - assert jnp.all(jnp.array_equal(unstacked[key], input_dict[key])), f"Failed on key: {key}" - - class M(eqx.Module): a: Any b: Any @@ -127,3 +69,127 @@ def test_to_from_state_dict(): m2 = from_state_dict(m2, state_dict) assert jnp.all(m2.a == a) assert jnp.all(m2.b == b) + + +class Module(eqx.Module): + named: hax.NamedArray + array: jax.Array + static: int = eqx.static_field() + + def __call__(self, x, *, key): + return x + self.array + self.static + + @staticmethod + def init(named, array, static): + return Module(named=named, array=array, static=static) + + +class Mod2(eqx.Module): + a: Stacked[Module] + + @staticmethod + def init(Block2, named, array, static): + return Mod2(a=Stacked.init(Block2, Module)(named=named, array=array, static=static)) + + +def test_tree_unstacking(): + Block = hax.Axis("block", 4) + E = hax.Axis("E", 10) + + initial_named = hax.random.uniform(jax.random.PRNGKey(0), (Block, E)) + + m = Stacked.init(Block, Module)(named=initial_named, 
array=jax.numpy.ones(Block.size), static=1) + + assert m.stacked.named.axes == (Block, E) + assert m.stacked.array.shape == (Block.size,) + assert m.stacked.static == 1 + + unstacked = _unstack_stacked(m) + + assert isinstance(unstacked, BlockSeq) + + z = tree_zeros_like(m) + + restacked = _restack_stacked(z, unstacked) + + assert restacked == m + + +def test_double_stacking(): + + Block1 = hax.Axis("Block1", 4) + Block2 = hax.Axis("Block2", 2) + + E = hax.Axis("E", 10) + + initial_named = hax.random.uniform(jax.random.PRNGKey(0), (Block1, Block2, E)) + + m_stacked = Stacked.init(Block1, Mod2)( + Block2, named=initial_named, array=jax.numpy.ones((Block1.size, Block2.size)), static=1 + ) + + m_unstacked = _unstack_stacked(m_stacked) + + # ensure there are no stacked left + leaves = jax.tree.leaves(m_unstacked, is_leaf=lambda x: isinstance(x, Stacked)) + assert not any(isinstance(leaf, Stacked) for leaf in leaves) + + m_restacked = _restack_stacked(tree_zeros_like(m_stacked), m_unstacked) + + assert m_stacked == m_restacked + + +def test_torch_compatible_state_dict_stacked(): + Block1 = hax.Axis("Block1", 4) + Block2 = hax.Axis("Block2", 2) + + E = hax.Axis("E", 10) + + initial_named = hax.random.uniform(jax.random.PRNGKey(0), (Block1, Block2, E)) + + m_stacked = Stacked.init(Block1, Mod2)( + Block2, named=initial_named, array=jax.numpy.ones((Block1.size, Block2.size)), static=1 + ) + + state_dict = hax.state_dict.to_torch_compatible_state_dict(m_stacked) + + # check for some keys: + assert "0.a.0.array" in state_dict + assert "1.a.1.named" in state_dict + + z = tree_zeros_like(m_stacked) + + m_unstacked = hax.state_dict.from_torch_compatible_state_dict(z, state_dict) + + assert m_stacked == m_unstacked + + +def test_torch_compatible_state_dict_stacked_linear(): + Block1 = hax.Axis("Block1", 4) + Block2 = hax.Axis("Block2", 2) + + E = hax.Axis("E", 10) + E2 = hax.Axis("E2", 5) + + class ModLinear(eqx.Module): + a: hax.nn.Stacked[hax.nn.Linear] + + @staticmethod + def init(Block2, key): + return ModLinear(a=hax.nn.Stacked.init(Block2, hax.nn.Linear)(E, E2, key=key)) + + m_stacked = Stacked.init(Block1, ModLinear)( + Block2, key=jax.random.split(jax.random.PRNGKey(1), (Block1.size, Block2.size)) + ) + + state_dict = hax.state_dict.to_torch_compatible_state_dict(m_stacked) + + # check for some keys: + assert "0.a.0.bias" in state_dict + assert "1.a.1.weight" in state_dict + + z = tree_zeros_like(m_stacked) + + m_unstacked = hax.state_dict.from_torch_compatible_state_dict(z, state_dict) + + assert m_stacked == m_unstacked