diff --git a/docs/_static/custom_css.css b/docs/_static/custom_css.css index d01752c8..7d81d54e 100644 --- a/docs/_static/custom_css.css +++ b/docs/_static/custom_css.css @@ -1,3 +1,8 @@ +/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ +html { + scroll-padding-top: 50px; +} + /* Fit the Twitter handle alongside the GitHub one in the top right. */ div.md-header__source { diff --git a/docs/all-of-equinox.md b/docs/all-of-equinox.md index 5da1e9c0..2cd3f6a4 100644 --- a/docs/all-of-equinox.md +++ b/docs/all-of-equinox.md @@ -145,6 +145,6 @@ See also the API reference on the left. !!! faq "FAQ" - One common question that gets asked: a lot of other libraries introduce custom `library.jit` etc. operations, specifically to work with `library.Module`. What makes the filtered transformations of Equinox different? + One common question: a lot of other libraries introduce custom `library.jit` etc. operations, specifically to work with `library.Module`. What makes the filtered transformations of Equinox different? - The answer is that filters are tools that apply to *any PyTree*; and models just happen to be PyTrees. There is no special coupling between a filtered transformation and `eqx.Module`. (Not to mention that the purpose of filtered transformations -- being able to include non-JAX-arrays in your model's parameters -- is itself unusual/impossible in many other libraries.) + The answer is that filter transformations are tools that apply to any PyTree. And models just happen to be PyTrees. The filtered transformations and `eqx.Module` are not coupled together. diff --git a/docs/api/stateful.md b/docs/api/experimental/stateful.md similarity index 58% rename from docs/api/stateful.md rename to docs/api/experimental/stateful.md index 6a0f11f4..ff9ff971 100644 --- a/docs/api/stateful.md +++ b/docs/api/experimental/stateful.md @@ -6,6 +6,8 @@ These operations can be used to introduce save/load JAX arrays as a side-effect This is considered experimental. + Stateful operations will not produce correct results under `jax.checkpoint` or `jax.pmap`. + !!! danger Really, **this is experimental**. Side effects can easily make your code do something unexpected. Whatever you're doing, you almost certainly do not need this. @@ -15,33 +17,35 @@ Use cases: - Something like [`equinox.experimental.BatchNorm`][], for which we would like to save the running statistics as a side-effect. - Implicitly passing information between loop iterations -- i.e. rather than explicitly via the `carry` argument to `lax.scan`. Perhaps you're using a third-party library that handles the `lax.scan`, that doesn't allow you pass your own information between iterations. -Example: -```python -import equinox as eqx -import jax -import jax.lax as lax -import jax.numpy as jnp +!!! 
example + + ```python + import equinox as eqx + import jax + import jax.lax as lax + import jax.numpy as jnp -index = eqx.experimental.StateIndex() -init = jnp.array(0) -eqx.experimental.set_state(index, init) + index = eqx.experimental.StateIndex() + init = jnp.array(0) + eqx.experimental.set_state(index, init) -@jax.jit -def scan_fun(_, __): - val = eqx.experimental.get_state(index, like=init) - val = val + 1 - eqx.experimental.set_state(index, val) - return None, val + @jax.jit + def scan_fun(_, __): + val = eqx.experimental.get_state(index, like=init) + val = val + 1 + eqx.experimental.set_state(index, val) + return None, val -_, out = lax.scan(scan_fun, None, xs=None, length=5) -print(out) # [1 2 3 4 5] -``` + _, out = lax.scan(scan_fun, None, xs=None, length=5) + print(out) # [1 2 3 4 5] + ``` --- ::: equinox.experimental.StateIndex selection: - members: false + members: + - __init__ --- diff --git a/docs/api/filtering/filtered-transformations.md b/docs/api/filtering/filtered-transformations.md index 22a3e7b3..d1338c38 100644 --- a/docs/api/filtering/filtered-transformations.md +++ b/docs/api/filtering/filtered-transformations.md @@ -1,6 +1,6 @@ # Filtered transformations -These typically combine [`equinox.partition`][], a filter function, and a JAX transformation, all together. +These typically combine [`equinox.partition`][], a [filter function](./filter-functions.md), and a JAX transformation, all together. Practically speaking these are usually the only kind of filtering you ever have to use. (But it's good to understand what e.g. [`equinox.partition`][] and [`equinox.is_array`][] are doing under the hood, just so that these don't seem too magical.) @@ -19,3 +19,11 @@ Practically speaking these are usually the only kind of filtering you ever have ::: equinox.filter_custom_vjp selection: members: false + +--- + +::: equinox.filter_vmap + +--- + +::: equinox.filter_pmap diff --git a/docs/api/nn/pool.md b/docs/api/nn/pool.md new file mode 100644 index 00000000..4cc74695 --- /dev/null +++ b/docs/api/nn/pool.md @@ -0,0 +1,43 @@ +# Pooling + +::: equinox.nn.Pool + selection: + members: + - __init__ + - __call__ + +--- + +::: equinox.nn.AvgPool1D + selection: + members: false + +--- + +::: equinox.nn.AvgPool2D + selection: + members: false + +--- + +::: equinox.nn.AvgPool3D + selection: + members: false + +--- + +::: equinox.nn.MaxPool1D + selection: + members: false + +--- + +::: equinox.nn.MaxPool2D + selection: + members: false + +--- + +::: equinox.nn.MaxPool3D + selection: + members: false diff --git a/docs/api/utilities/manipulation.md b/docs/api/utilities/manipulation.md new file mode 100644 index 00000000..cf9e96ac --- /dev/null +++ b/docs/api/utilities/manipulation.md @@ -0,0 +1,11 @@ +# Manipulating PyTrees + +::: equinox.apply_updates + +--- + +::: equinox.tree_at + +--- + +::: equinox.tree_inference diff --git a/docs/api/helpers.md b/docs/api/utilities/miscellaneous.md similarity index 51% rename from docs/api/helpers.md rename to docs/api/utilities/miscellaneous.md index ac51ecf1..9fae6d42 100644 --- a/docs/api/helpers.md +++ b/docs/api/utilities/miscellaneous.md @@ -1,12 +1,4 @@ -# Helpers for PyTrees - -::: equinox.apply_updates - ---- - -::: equinox.tree_at - ---- +# Miscellaneous ::: equinox.tree_pformat diff --git a/docs/api/utilities/serialisation.md b/docs/api/utilities/serialisation.md new file mode 100644 index 00000000..67655095 --- /dev/null +++ b/docs/api/utilities/serialisation.md @@ -0,0 +1,7 @@ +# Serialisation + +::: equinox.tree_serialise_leaves + +--- + 
+::: equinox.tree_deserialise_leaves diff --git a/docs/faq.md b/docs/faq.md index 66fabfc8..3c9f8796 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -37,11 +37,11 @@ Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as The resolution is simple: just don't store the same object in multiple places in the PyTree. -## I cannot feed higher-order tensors (e.g. with batch dimensions) into my model. +## How do I input higher-order tensors (e.g. with batch dimensions) into my model? -Use [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap). This maps arbitrary JAX operations -- including any Equinox module -- over additional dimensions. +Use [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap). This maps arbitrary JAX operations -- including any Equinox module -- over additional dimensions (such as batch dimensions). -For example if `x` is an array/tensor of shape `(batch_size, input_size)`, the following code in PyTorch: +For example if `x` is an array/tensor of shape `(batch_size, input_size)`, then the following PyTorch code: ```python import torch @@ -50,8 +50,7 @@ linear = torch.nn.Linear(input_size, output_size) y = linear(x) ``` -is equivalent to the following code in Equinox: - +is equivalent to the following Equinox code: ```python import jax import equinox as eqx diff --git a/docs/requirements.txt b/docs/requirements.txt index e304ee81..608ae930 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,7 @@ mkdocs-material==7.3.6 # Theme pymdown-extensions==9.4 # Markdown extensions e.g. to handle LaTeX. mkdocstrings==0.17.0 # Autogenerate documentation from docstrings. mknotebooks==0.7.1 # Turn Jupyter Lab notebooks into webpages. -pytkdocs_tweaks==0.0.4 # Tweaks mkdocstrings to improve various aspects +pytkdocs_tweaks==0.0.5 # Tweaks mkdocstrings to improve various aspects mkdocs_include_exclude_files==0.0.1 # Tweak which files are included/excluded jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings. 
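To make the `jax.vmap` pattern described in the `docs/faq.md` change above concrete, here is a minimal sketch of the Equinox side of the comparison (the sizes and PRNG key are illustrative assumptions, not part of the original snippet):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

input_size, output_size, batch_size = 3, 4, 8  # illustrative only
x = jnp.zeros((batch_size, input_size))

linear = eqx.nn.Linear(input_size, output_size, key=jr.PRNGKey(0))
# jax.vmap maps the (unbatched) module over the leading batch dimension of x.
y = jax.vmap(linear)(x)
assert y.shape == (batch_size, output_size)
```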
diff --git a/equinox/__init__.py b/equinox/__init__.py index c841b770..771dcdaa 100644 --- a/equinox/__init__.py +++ b/equinox/__init__.py @@ -12,8 +12,10 @@ from .jit import filter_jit from .module import Module, static_field from .pretty_print import tree_pformat -from .tree import tree_at, tree_equal +from .serialisation import tree_deserialise_leaves, tree_serialise_leaves +from .tree import tree_at, tree_equal, tree_inference from .update import apply_updates +from .vmap_pmap import filter_pmap, filter_vmap -__version__ = "0.4.0" +__version__ = "0.5.0" diff --git a/equinox/compile_utils.py b/equinox/compile_utils.py new file mode 100644 index 00000000..18059145 --- /dev/null +++ b/equinox/compile_utils.py @@ -0,0 +1,46 @@ +import functools as ft +from typing import Any + +import jax + +from .filters import combine, partition +from .module import Module, static_field + + +def hashable_partition(pytree, filter_spec): + dynamic, static = partition(pytree, filter_spec) + static_leaves, static_treedef = jax.tree_flatten(static) + static_leaves = tuple(static_leaves) + return dynamic, static_leaves, static_treedef + + +def hashable_combine(dynamic, static_leaves, static_treedef): + static = jax.tree_unflatten(static_treedef, static_leaves) + return combine(dynamic, static) + + +class Static(Module): + value: Any = static_field() + + +def strip_wrapped_partial(fun): + if hasattr(fun, "__wrapped__"): # ft.wraps + return strip_wrapped_partial(fun.__wrapped__) + if isinstance(fun, ft.partial): + return strip_wrapped_partial(fun.func) + return fun + + +def compile_cache(fun): + @ft.lru_cache(maxsize=None) + def _cache(leaves, treedef): + args, kwargs = jax.tree_unflatten(treedef, leaves) + return fun(*args, **kwargs) + + @ft.wraps(fun) + def _fun(*args, **kwargs): + leaves, treedef = jax.tree_flatten((args, kwargs)) + leaves = tuple(leaves) + return _cache(leaves, treedef) + + return _fun diff --git a/equinox/custom_types.py b/equinox/custom_types.py index fcda0d1a..70e3e0b6 100644 --- a/equinox/custom_types.py +++ b/equinox/custom_types.py @@ -1,9 +1,11 @@ import inspect import typing -from typing import Generic, Tuple, TypeVar, Union +from typing import Any, Callable, Generic, Tuple, TypeVar, Union import jax +from .doc_utils import doc_repr + # Custom flag we set when generating documentation. # We do a lot of custom hackery in here to produce nice-looking docs. @@ -105,4 +107,9 @@ def __class_getitem__(cls, item): return PyTree +sentinel = doc_repr(object(), "sentinel") + TreeDef = type(jax.tree_structure(0)) + +ResolvedBoolAxisSpec = bool +BoolAxisSpec = Union[ResolvedBoolAxisSpec, Callable[[Any], ResolvedBoolAxisSpec]] diff --git a/equinox/doc_utils.py b/equinox/doc_utils.py new file mode 100644 index 00000000..873d180f --- /dev/null +++ b/equinox/doc_utils.py @@ -0,0 +1,32 @@ +import typing +from types import FunctionType +from typing import Any + + +# Inherits from type so that _WithRepr instances are types and can be used as +# e.g. 
Sequence[_WithRepr(...)] +class _WithRepr(type): + def __new__(self, string): + out = super().__new__(self, string, (), {}) + # prevent the custom typing repr from doing the wrong thing + out.__module__ = "builtins" + return out + + def __init__(self, string): + self.string = string + + def __repr__(self): + return self.string + + +def doc_repr(obj: Any, string: str): + if getattr(typing, "GENERATING_DOCUMENTATION", False): + return _WithRepr(string) + else: + return obj + + +def doc_strip_annotations(fn: FunctionType) -> FunctionType: + if getattr(typing, "GENERATING_DOCUMENTATION", False): + fn.__annotations__ = None + return fn diff --git a/equinox/experimental/batch_norm.py b/equinox/experimental/batch_norm.py index 9160ad4d..2a35de4a 100644 --- a/equinox/experimental/batch_norm.py +++ b/equinox/experimental/batch_norm.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Sequence, Union import jax import jax.lax as lax @@ -60,7 +60,7 @@ class BatchNorm(Module): bias: Optional[Array["input_size"]] first_time_index: StateIndex state_index: StateIndex - axis_name: str + axis_name: Union[str, Sequence[str]] inference: bool input_size: int = static_field() eps: float = static_field() @@ -81,7 +81,8 @@ def __init__( - `input_size`: The number of channels in the input array. - `axis_name`: The name of the batch axis to compute statistics over, as passed - to `axis_name` in `jax.vmap` or `jax.pmap`. + to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (tuple, + list) of strings to compute statistics over multiple named axes. - `eps`: Value added to the denominator for numerical stability. - `channelwise_affine`: Whether the module has learnable channel-wise affine parameters. @@ -89,7 +90,9 @@ def __init__( value between 0 and 1 exclusive. - `inference`: If `False` then the batch means and variances will be calculated and used to update the running statistics. If `True` then the running - statistics are directly used for normalisation. + statistics are directly used for normalisation. This may be toggled with + [`equinox.tree_inference`][] or overridden during + [`equinox.experimental.BatchNorm.__call__`][]. """ super().__init__(**kwargs) diff --git a/equinox/experimental/spectral_norm.py b/equinox/experimental/spectral_norm.py index f2ad00dd..273bab34 100644 --- a/equinox/experimental/spectral_norm.py +++ b/equinox/experimental/spectral_norm.py @@ -135,6 +135,7 @@ def __init__( - `num_power_iterations`: The number of power iterations to apply every time the array is accessed. - `eps`: Epsilon for numerical stability when calculating norms. - `inference`: Whether this is in inference mode, at which time no power iterations are performed. + This may be toggled with [`equinox.tree_inference`][]. - `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation. (Keyword only argument.) """ super().__init__(**kwargs) diff --git a/equinox/experimental/stateful.py b/equinox/experimental/stateful.py index 2b1e34e3..743f603c 100644 --- a/equinox/experimental/stateful.py +++ b/equinox/experimental/stateful.py @@ -73,19 +73,6 @@ def __call__(self, x): ci(x) ci(y) ``` - - !!! warning - - You should not modify the `inference` flag whilst inside a JIT region. For - example, the following will produced undesired behaviour: - - ```python - @jax.jit - def f(...): - ... - index = eqx.tree_at(lambda i: i.inference, index, True) - ... 
- ``` """ _obj: _IndexObj = static_field(repr=False) @@ -99,7 +86,21 @@ def __init__(self, inference: bool = False): - `inference`: If `True`, then the state can only be get, but not set. All stored states will looked up when crossing the JIT boundary -- rather than dynamically at runtime -- and treated as inputs to the XLA computation - graph. This improves speed at runtime. + graph. This improves speed at runtime. This may be toggled with + [`equinox.tree_inference`][]. + + !!! warning + + You should not modify the `inference` flag whilst inside a JIT region. For + example, the following will produced undefined behaviour: + + ```python + @jax.jit + def f(...): + ... + index = eqx.tree_at(lambda i: i.inference, index, True) + ... + ``` """ self._obj = _IndexObj() self._version = _FixedInt(-1) diff --git a/equinox/filters.py b/equinox/filters.py index bc871da7..50574c3f 100644 --- a/equinox/filters.py +++ b/equinox/filters.py @@ -1,10 +1,10 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, Optional import jax import jax.numpy as jnp import numpy as np -from .custom_types import PyTree +from .custom_types import BoolAxisSpec, PyTree, ResolvedBoolAxisSpec # @@ -17,7 +17,6 @@ def is_array(element: Any) -> bool: return isinstance(element, jnp.ndarray) -# Does _not_ do a try/except on jnp.asarray(element) because that's very slow. # Chosen to match # https://github.com/google/jax/blob/4a17c78605e7fc69a69a999e2f6298db79d3837a/jax/_src/numpy/lax_numpy.py#L542 # noqa: E501 def is_array_like(element: Any) -> bool: @@ -51,17 +50,26 @@ def is_inexact_array_like(element: Any) -> bool: # -def _make_filter_tree(mask: Union[bool, Callable[[Any], bool]], arg: Any) -> bool: - if isinstance(mask, bool): - return mask - elif callable(mask): - return jax.tree_map(mask, arg) - else: - raise ValueError("`filter_spec` must consist of booleans and callables only.") +def _make_filter_tree(is_leaf): + def _filter_tree(mask: BoolAxisSpec, arg: Any) -> ResolvedBoolAxisSpec: + if isinstance(mask, bool): + return jax.tree_map(lambda _: mask, arg, is_leaf=is_leaf) + elif callable(mask): + return jax.tree_map(mask, arg, is_leaf=is_leaf) + else: + raise ValueError( + "`filter_spec` must consist of booleans and callables only." + ) + + return _filter_tree def filter( - pytree: PyTree, filter_spec: PyTree, inverse: bool = False, replace: Any = None + pytree: PyTree, + filter_spec: PyTree[BoolAxisSpec], + inverse: bool = False, + replace: Any = None, + is_leaf: Optional[Callable[[Any], bool]] = None, ) -> PyTree: """ Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying @@ -79,6 +87,11 @@ def filter( - `inverse` switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced. - `replace` is what to replace any falsey leaves with. Defaults to `None`. + - `is_leaf`: Optional function called at each node of the PyTree. It should return + a boolean. `True` indicates that the whole subtree should be treated as leaf; + `False` indicates that the subtree should be traversed as a PyTree. This is + mostly useful for evaluating a callable `filter_spec` on a node instead of a + leaf. **Returns:** @@ -89,25 +102,30 @@ def filter( A common special case is `equinox.filter(pytree, equinox.is_array)`. Then `equinox.is_array` is evaluted on all of `pytree`'s leaves, and each leaf then kept or replaced. - - !!! info - - See also [`equinox.combine`][] to reconstitute the PyTree again. 
""" inverse = bool(inverse) # just in case, to make the != trick below work reliably - filter_tree = jax.tree_map(_make_filter_tree, filter_spec, pytree) + filter_tree = jax.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree) return jax.tree_map( lambda mask, x: x if bool(mask) != inverse else replace, filter_tree, pytree ) -def partition(pytree: PyTree, filter_spec: PyTree, replace: Any = None) -> PyTree: +def partition( + pytree: PyTree, + filter_spec: PyTree[BoolAxisSpec], + replace: Any = None, + is_leaf: Optional[Callable[[Any], bool]] = None, +) -> PyTree: """Equivalent to `filter(...), filter(..., inverse=True)`, but slightly more efficient. + + !!! info + + See also [`equinox.combine`][] to reconstitute the PyTree again. """ - filter_tree = jax.tree_map(_make_filter_tree, filter_spec, pytree) + filter_tree = jax.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree) left = jax.tree_map(lambda mask, x: x if mask else replace, filter_tree, pytree) right = jax.tree_map(lambda mask, x: replace if mask else x, filter_tree, pytree) return left, right diff --git a/equinox/grad.py b/equinox/grad.py index f256a0a1..2ae38d01 100644 --- a/equinox/grad.py +++ b/equinox/grad.py @@ -1,18 +1,77 @@ import functools as ft import types import typing +import warnings +from typing import Any, Callable, Dict import jax +from .custom_types import BoolAxisSpec, PyTree, sentinel +from .doc_utils import doc_strip_annotations from .filters import combine, is_array, is_inexact_array, partition +from .module import Module, module_update_wrapper +class _ValueAndGradWrapper(Module): + _fun: Callable + _arg: PyTree[BoolAxisSpec] + _gradkwargs: Dict[str, Any] + + # Try to avoid clashes with existing argument names. + # TODO: use "/" once we're on Python 3.8. + def __call__(__self, __x, *args, **kwargs): + @ft.partial(jax.value_and_grad, argnums=0, **__self._gradkwargs) + def fun_value_and_grad(_diff_x, _nondiff_x, *_args, **_kwargs): + _x = combine(_diff_x, _nondiff_x) + return __self._fun(_x, *_args, **_kwargs) + + diff_x, nondiff_x = partition(__x, __self._arg) + return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self + return jax.tree_util.Partial(self, instance) + + +class _GradWrapper(Module): + _fun_value_and_grad: _ValueAndGradWrapper + _has_aux: bool + + def __call__(__self, *args, **kwargs): + value, grad = __self._fun_value_and_grad(*args, **kwargs) + if __self._has_aux: + _, aux = value + return grad, aux + else: + return grad + + def __get__(self, instance, owner): + if instance is None: + return self + return jax.tree_util.Partial(self, instance) + + +@doc_strip_annotations def filter_value_and_grad( - fun, *, filter_spec=is_inexact_array, argnums=None, **gradkwargs -): + fun: Callable = sentinel, + *, + arg: PyTree[BoolAxisSpec] = is_inexact_array, + **gradkwargs, +) -> Callable: """As [`equinox.filter_grad`][], except that it is `jax.value_and_grad` that is wrapped. """ + + if fun is sentinel: + return ft.partial(filter_value_and_grad, arg=arg, **gradkwargs) + + filter_spec = gradkwargs.pop("filter_spec", None) + if filter_spec is not None: + warnings.warn("For brevity the `filter_spec` argument has been renamed `arg`") + arg = filter_spec + + argnums = gradkwargs.pop("argnums", None) if argnums is not None: raise ValueError( "`argnums` should not be passed. If you need to differentiate " @@ -20,26 +79,28 @@ def filter_value_and_grad( "as the first argument." 
) - @ft.partial(jax.value_and_grad, argnums=0, **gradkwargs) - def fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs): - x = combine(diff_x, nondiff_x) - return fun(x, *args, **kwargs) + return module_update_wrapper(_ValueAndGradWrapper(fun, arg, gradkwargs), fun) - @ft.wraps(fun) - def fun_value_and_grad_wrapper(x, *args, **kwargs): - diff_x, nondiff_x = partition(x, filter_spec) - return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs) - return fun_value_and_grad_wrapper +@doc_strip_annotations +def filter_grad( + fun: Callable = sentinel, + *, + arg: PyTree[BoolAxisSpec] = is_inexact_array, + **gradkwargs, +): + """Wraps together [`equinox.partition`][] and `jax.grad`. + !!! info + + By default, all inexact (floating-point) JAX arrays are differentiated. Any + nondifferentiable leaves will have `None` as the gradient. -def filter_grad(fun, *, filter_spec=is_inexact_array, **gradkwargs): - """Wraps together [`equinox.partition`][] and `jax.grad`. **Arguments:** - `fun` is a pure function to JIT compile. - - `filter_spec` is a PyTree whose structure should be a prefix of the structure of + - `arg` is a PyTree whose structure should be a prefix of the structure of the **first** argument to `fun`. It behaves as the `filter_spec` argument to [`equinox.filter`][]. Truthy values will be differentiated; falsey values will not. @@ -52,14 +113,6 @@ def filter_grad(fun, *, filter_spec=is_inexact_array, **gradkwargs): [`equinox.apply_updates`][] for a convenience function that will only attempt to apply non-`None` updates. - !!! info - - A very important special case is to trace all inexact (i.e. floating point) - JAX arrays and treat all other objects as nondifferentiable. - - This is accomplished with `filter_spec=equinox.is_inexact_array`, which is the - default. - !!! 
tip If you need to differentiate multiple objects, then put them together into a @@ -76,22 +129,13 @@ def grad_func(x__y): ``` """ - has_aux = gradkwargs.get("has_aux", False) - - fun_value_and_grad = filter_value_and_grad( - fun, filter_spec=filter_spec, **gradkwargs - ) + if fun is sentinel: + return ft.partial(filter_grad, arg=arg, **gradkwargs) - @ft.wraps(fun) - def fun_grad(*args, **kwargs): - value, grad = fun_value_and_grad(*args, **kwargs) - if has_aux: - _, aux = value - return grad, aux - else: - return grad + has_aux = gradkwargs.get("has_aux", False) - return fun_grad + fun_value_and_grad = filter_value_and_grad(fun, arg=arg, **gradkwargs) + return module_update_wrapper(_GradWrapper(fun_value_and_grad, has_aux), fun) class filter_custom_vjp: diff --git a/equinox/jit.py b/equinox/jit.py index 2c7e3b9b..05e39680 100644 --- a/equinox/jit.py +++ b/equinox/jit.py @@ -1,95 +1,200 @@ import functools as ft -from typing import Any +import inspect +import warnings +from types import FunctionType +from typing import Any, Callable, Sequence import jax -from .filters import combine, is_array, partition -from .module import Module, static_field +from .compile_utils import ( + compile_cache, + hashable_combine, + hashable_partition, + Static, + strip_wrapped_partial, +) +from .custom_types import BoolAxisSpec, PyTree, sentinel, TreeDef +from .doc_utils import doc_strip_annotations +from .filters import combine, filter, is_array, partition +from .module import Module, module_update_wrapper -class _Static(Module): - value: Any = static_field() +@compile_cache +def _filter_jit_cache(unwrapped_fun, **jitkwargs): + @ft.partial(jax.jit, static_argnums=1, **jitkwargs) + @ft.wraps(unwrapped_fun) + def fun_wrapped(dynamic, static): + dynamic_fun, dynamic_spec = dynamic + ( + static_fun_treedef, + static_fun_leaves, + static_spec_treedef, + static_spec_leaves, + filter_out, + ) = static + fun = hashable_combine(dynamic_fun, static_fun_leaves, static_fun_treedef) + args, kwargs = hashable_combine( + dynamic_spec, static_spec_leaves, static_spec_treedef + ) + out = fun(*args, **kwargs) + dynamic_out, static_out = partition(out, filter_out) + return dynamic_out, Static(static_out) + + return fun_wrapped -@ft.lru_cache(maxsize=None) -def _f_wrapped_cache(fun, **jitkwargs): - @ft.partial(jax.jit, static_argnums=(1, 2, 3), **jitkwargs) - @ft.wraps(fun) - def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return): - static = jax.tree_unflatten(static_treedef, static_leaves) - f, args, kwargs = combine(dynamic, static) - out = f(*args, **kwargs) - dynamic_out, static_out = partition(out, filter_spec_return) - return dynamic_out, _Static(static_out) +class _JitWrapper(Module): + _new_style: bool + _signature: inspect.Signature + _dynamic_fun: PyTree[Any] + _static_fun_treedef: TreeDef + _static_fun_leaves: Sequence[Any] + _filter_default: BoolAxisSpec + _filter_spec: PyTree[BoolAxisSpec] + _filter_out: PyTree[Any] + _cached: FunctionType - return f_wrapped + def _fun_wrapper(self, is_lower, args, kwargs): + if self._new_style: + bound = self._signature.bind(*args, **kwargs) + bound.apply_defaults() + args = bound.args + kwargs = bound.kwargs + filter_args, filter_kwargs = self._filter_spec + filter_args = filter_args + (self._filter_default,) * ( + len(args) - len(filter_args) + ) + filter_kwargs = { + key: filter_kwargs.get(key, self._filter_default) for key in kwargs + } + filter_spec = (filter_args, filter_kwargs) + else: + filter_spec = self._filter_spec + dynamic_spec, static_spec_leaves, 
static_spec_treedef = hashable_partition( + (args, kwargs), filter_spec + ) + dynamic = (self._dynamic_fun, dynamic_spec) + static = ( + self._static_fun_treedef, + self._static_fun_leaves, + static_spec_treedef, + static_spec_leaves, + self._filter_out, + ) + if is_lower: + return self._cached.lower(dynamic, static) + else: + dynamic_out, static_out = self._cached(dynamic, static) + return combine(dynamic_out, static_out.value) -def _strip_wrapped_partial(fun): - """Preserve the outermost wraps call's docstring or traverse to the inner function""" - if hasattr(fun, "__wrapped__"): # ft.wraps - return _strip_wrapped_partial(fun.__wrapped__) - if isinstance(fun, ft.partial): - return _strip_wrapped_partial(fun.func) - return fun + def __call__(__self, *args, **kwargs): + return __self._fun_wrapper(False, args, kwargs) + def lower(__self, *args, **kwargs): + return __self._fun_wrapper(True, args, kwargs) -def _process_args(args, kwargs, dynamic_fun, static_fun, filter_spec): - dynamic_args_kwargs, static_args_kwargs = partition((args, kwargs), filter_spec) - dynamic = (dynamic_fun,) + dynamic_args_kwargs - static = (static_fun,) + static_args_kwargs - static_leaves, static_treedef = jax.tree_flatten(static) - static_leaves = tuple(static_leaves) - inner_fun = _strip_wrapped_partial(static_fun) - return inner_fun, dynamic, static_treedef, static_leaves + def __get__(self, instance, owner): + if instance is None: + return self + return jax.tree_util.Partial(self, instance) +@doc_strip_annotations def filter_jit( - fun, + fun: Callable = sentinel, *, - filter_spec=is_array, - filter_spec_return=is_array, - filter_spec_fun=is_array, + default: BoolAxisSpec = is_array, + fn: PyTree[BoolAxisSpec] = is_array, + args: PyTree[BoolAxisSpec] = (), + kwargs: PyTree[BoolAxisSpec] = None, + out: PyTree[BoolAxisSpec] = is_array, **jitkwargs -): +) -> Callable: """Wraps together [`equinox.partition`][] and `jax.jit`. + !!! info + + By default, all JAX arrays are traced, and all other types are held static. + **Arguments:** + In each of the following cases, `True` indicates that an argument should be traced, + `False` indicates that an argument should be held static, and functions + `Leaf -> bool` are mapped and evaluated on every leaf of their subtree. + - `fun` is a pure function to JIT compile. - - `filter_spec` is a PyTree whose structure should be a prefix of the structure of - the inputs to `fun`. It behaves as the `filter_spec` argument to - [`equinox.filter`][]. Truthy values will be traced; falsey values will be held - static. - - `filter_spec_return` is a PyTree whose structure should be a prefix of the - structure of the outputs of `fun`. It behaves as the `filter_spec` argument to - [`equinox.filter`][]. Truthy values should be tracers; falsey values are any - (non-tracer) auxiliary information to return. - - `filter_spec_fun` is a PyTree whose structure should be a prefix of `fun` itself. - (Note that `fun` may be any callable -- e.g. a bound method, or a class - implementing `__call__` -- and not necessarily only a function.) It behaves as - the `filter_spec` argument to [`equinox.filter`][]. Truthy values will be - traced; falsey values will be held static. + - `default` should be a `bool` or a function `Leaf -> bool`, and is applied by + default to every argument and keyword argument to `fun`. + - `args` is an optional per-argument override for `default`, and should be a tuple + of PyTrees with leaves that are either `bool`s or functions `Leaf -> bool`. 
+      The PyTree structures should be prefixes of the corresponding input to `fun`.
+    - `kwargs` is an optional per-keyword-argument override for `default` and should be
+      a dictionary, whose keys are the names of arguments to `fun`, and whose values
+      are PyTrees with leaves that are either `bool`s or functions `Leaf -> bool`. The
+      PyTree structures should be prefixes of the corresponding input to `fun`.
+    - `out` should be a PyTree with leaves that are either `bool`s or functions
+      `Leaf -> bool`. The PyTree structure should be a prefix of the output of `fun`.
+      Truthy values should be tracers; falsey values are any (non-tracer) auxiliary
+      information to return.
+    - `fn` should be a PyTree with leaves that are either `bool`s or functions
+      `Leaf -> bool`. The PyTree structure should be a prefix of `fun` itself. (Note
+      that `fun` may be any callable, e.g. a bound method, or a class implementing
+      `__call__`, and doesn't have to be a normal Python function.)
     - `**jitkwargs` are any other keyword arguments to `jax.jit`.

-    !!! info
-
-        Specifically, if calling `fun(*args, **kwargs)`, then `filter_spec` must
-        have a structure which is a prefix for `(args, kwrgs)`.
+    When `args`, `kwargs`, `out`, `fn` are prefixes of the corresponding input, their
+    value will be mapped over the input PyTree.

     **Returns:**

     The JIT'd version of `fun`.

-    !!! info
+    !!! example
+
+        ```python
+        @eqx.filter_jit
+        def f(x, y):  # both args traced if arrays, static if non-arrays
+            return x + y
+
+        @eqx.filter_jit(kwargs=dict(x=False))
+        def g(x, y):  # x held static; y is traced if array, static if non-array
+            return x + y
+
+        @eqx.filter_jit(args=(True,))
+        def h(x):
+            return x
+
+        @eqx.filter_jit
+        def apply(f, x):
+            return f(x)

-        A very important special case is to trace all JAX arrays and treat all other
-        objects as static.
+        f(jnp.array(1), jnp.array(2))  # both args traced
+        f(jnp.array(1), 2)  # first arg traced, second arg static
+        f(1, 2)  # both args static

-        This is accomplished with `filter_spec=equinox.is_array`, which is the default.
-        (It is relatively unusual to need different behaviour to this.)
+        g(1, jnp.array(2))  # first arg static, second arg traced
+        g(1, 2)  # both args static
+
+        h(1)  # traced
+        h(jnp.array(1))  # traced
+        h("hi")  # not a trace-able JAX type, so error
+
+        apply(lambda x: x + 1, jnp.array(1))  # first arg static, second arg traced.
+        ```
     """
+    if fun is sentinel:
+        return ft.partial(
+            filter_jit,
+            default=default,
+            fn=fn,
+            args=args,
+            kwargs=kwargs,
+            out=out,
+            **jitkwargs
+        )
+
     if any(
         x in jitkwargs for x in ("static_argnums", "static_argnames", "donate_argnums")
     ):
@@ -98,34 +203,84 @@ def filter_jit(
             "'donate_argnums'."
         )

-    # We choose not to make a distinction between ([arg, ..., arg], kwargs) and ((arg, ..., arg), kwargs)
-    if (
-        isinstance(filter_spec, tuple)
-        and len(filter_spec) == 2
-        and isinstance(filter_spec[0], list)
-    ):
-        filter_spec = (tuple(filter_spec[0]), filter_spec[1])
+    if kwargs is None:
+        kwargs = {}

-    dynamic_fun, static_fun = partition(fun, filter_spec_fun)
+    # Original names are to provide a nice API, but they're too ambiguous for the code.
+ filter_default = default + filter_fn = fn + filter_args = args + filter_kwargs = kwargs + filter_out = out + del default, fn, args, kwargs, out - @ft.wraps(fun) - def fun_wrapper(*args, **kwargs): - inner_fun, dynamic, static_treedef, static_leaves = _process_args( - args, kwargs, dynamic_fun, static_fun, filter_spec - ) - dynamic_out, static_out = _f_wrapped_cache(inner_fun, **jitkwargs)( - dynamic, static_treedef, static_leaves, filter_spec_return - ) - return combine(dynamic_out, static_out.value) + signature = inspect.signature(fun) - def lower(*args, **kwargs): - inner_fun, dynamic, static_treedef, static_leaves = _process_args( - args, kwargs, dynamic_fun, static_fun, filter_spec + # Backward compatibility + filter_spec = jitkwargs.pop("filter_spec", is_array) + filter_spec_return = jitkwargs.pop("filter_spec_return", is_array) + filter_spec_fun = jitkwargs.pop("filter_spec_fun", is_array) + if any( + x is not is_array for x in (filter_spec, filter_spec_return, filter_spec_fun) + ): + # Old API + + warnings.warn( + "`filter_spec` is deprecated in favour of the new `args`, `kwargs` interface." ) - return _f_wrapped_cache(inner_fun, **jitkwargs).lower( - dynamic, static_treedef, static_leaves, filter_spec_return + if ( + any(x is not is_array for x in (filter_default, filter_fn, filter_out)) + or filter_args != () + or filter_kwargs != {} + ): + raise ValueError( + "Cannot use deprecated `filter_spec` at the same time as the new `args`, `kwargs` interface." + ) + + # We choose not to make a distinction between ([arg, ..., arg], kwargs) and ((arg, ..., arg), kwargs) + if ( + isinstance(filter_spec, tuple) + and len(filter_spec) == 2 + and isinstance(filter_spec[0], list) + ): + filter_spec = (tuple(filter_spec[0]), filter_spec[1]) + + filter_fn = filter_spec_fun + filter_out = filter_spec_return + new_style = False + else: + # New API + + signature_default = signature.replace( + parameters=[ + p + if p.kind + in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + else p.replace(default=filter_default) + for p in signature.parameters.values() + ] ) + filter_bound = signature_default.bind_partial(*filter_args, **filter_kwargs) + filter_bound.apply_defaults() + filter_spec = (filter_bound.args, filter_bound.kwargs) + new_style = True + # ~Backward compatibility - fun_wrapper.lower = lower + unwrapped_fun = filter(strip_wrapped_partial(fun), filter_fn, inverse=True) + dynamic_fun, static_fun_leaves, static_fun_treedef = hashable_partition( + fun, filter_fn + ) + cached = _filter_jit_cache(unwrapped_fun, **jitkwargs) - return fun_wrapper + jit_wrapper = _JitWrapper( + _new_style=new_style, + _signature=signature, + _dynamic_fun=dynamic_fun, + _static_fun_treedef=static_fun_treedef, + _static_fun_leaves=static_fun_leaves, + _filter_default=filter_default, + _filter_spec=filter_spec, + _filter_out=filter_out, + _cached=cached, + ) + return module_update_wrapper(jit_wrapper, fun) diff --git a/equinox/module.py b/equinox/module.py index db5cd05c..b93e660e 100644 --- a/equinox/module.py +++ b/equinox/module.py @@ -2,6 +2,7 @@ import functools as ft import inspect from dataclasses import dataclass, field, fields +from typing import Type import jax @@ -10,8 +11,9 @@ def static_field(**kwargs): - """Used for marking that a field should _not_ be treated as part of the PyTree - of a [`equinox.Module`][]. (And is instead just treated as extra metadata.) + """Used for marking that a field should _not_ be treated as a leaf of the PyTree + of a [`equinox.Module`][]. 
(And is instead treated as part of the structure, i.e. + as extra metadata.) !!! example @@ -20,9 +22,10 @@ class MyModule(equinox.Module): normal_field: int static_field: int = equinox.static_field() - mymodule = MyModule() + mymodule = MyModule("normal", "static") leaves, treedef = jax.tree_flatten(mymodule) - assert len(leaves) == 1 + assert leaves == ["normal"] + assert "static" in str(treedef) ``` In practice this should rarely be used; it is usually preferential to just filter @@ -55,9 +58,23 @@ def __get__(self, instance, owner): return jax.tree_util.Partial(self.method, instance) +def _not_magic(k: str) -> bool: + return not (k.startswith("__") and k.endswith("__")) + + @ft.lru_cache(maxsize=128) -def _make_initable(cls): - field_names = {field.name for field in fields(cls)} +def _make_initable(cls: Type["Module"], wraps: bool) -> Type["Module"]: + if wraps: + field_names = { + "__module__", + "__name__", + "__qualname__", + "__doc__", + "__annotations__", + "__wrapped__", + } + else: + field_names = {field.name for field in fields(cls)} class _InitableModule(cls): pass @@ -75,16 +92,12 @@ def __setattr__(self, name, value): return _InitableModule -def _has_dataclass_init(cls): +def _has_dataclass_init(cls: Type["Module"]) -> bool: if "__init__" in cls.__dict__: return False return cls._has_dataclass_init -def _not_magic(k): - return not (k.startswith("__") and k.endswith("__")) - - # Inherits from abc.ABCMeta as a convenience for a common use-case. # It's not a feature we use ourselves. class _ModuleMeta(abc.ABCMeta): @@ -112,10 +125,12 @@ def __call__(cls, *args, **kwargs): self = cls.__new__(cls, *args, **kwargs) # Defreeze it during __init__ - initable_cls = _make_initable(cls) + initable_cls = _make_initable(cls, wraps=False) object.__setattr__(self, "__class__", initable_cls) - cls.__init__(self, *args, **kwargs) - object.__setattr__(self, "__class__", cls) + try: + cls.__init__(self, *args, **kwargs) + finally: + object.__setattr__(self, "__class__", cls) missing_names = { field.name @@ -254,3 +269,19 @@ def tree_unflatten(cls, aux, dynamic_field_values): for name, value in zip(static_field_names, static_field_values): object.__setattr__(self, name, value) return self + + +# Modifies in-place, just like functools.update_wrapper +def module_update_wrapper(wrapper: Module, wrapped) -> Module: + cls = wrapper.__class__ + initable_cls = _make_initable(cls, wraps=True) + object.__setattr__(wrapper, "__class__", initable_cls) + try: + # updated = ("__dict__",) is the default, but that's a bit much. + # It's common/possible for wrapper and wrapped to both be classes + # implementing __call__, in which case copying __dict__ over basically + # just breaks the wrapper class. 
+ ft.update_wrapper(wrapper, wrapped, updated=()) + finally: + object.__setattr__(wrapper, "__class__", cls) + return wrapper diff --git a/equinox/nn/__init__.py b/equinox/nn/__init__.py index 7fcd6578..d901dbaa 100644 --- a/equinox/nn/__init__.py +++ b/equinox/nn/__init__.py @@ -14,4 +14,5 @@ from .embedding import Embedding from .linear import Identity, Linear from .normalisation import LayerNorm +from .pool import AvgPool1D, AvgPool2D, AvgPool3D, MaxPool1D, MaxPool2D, MaxPool3D, Pool from .rnn import GRUCell, LSTMCell diff --git a/equinox/nn/attention.py b/equinox/nn/attention.py index fcefabe6..7a06a352 100644 --- a/equinox/nn/attention.py +++ b/equinox/nn/attention.py @@ -105,6 +105,7 @@ def __init__( use_value_bias: bool = False, use_output_bias: bool = False, dropout_p: float = 0.0, + inference: bool = False, *, key: "jax.random.PRNGKey", **kwargs, @@ -126,6 +127,10 @@ def __init__( - `use_value_bias`: Whether to use a bias term in the value projections. - `use_output_bias`: Whether to use a bias term in the output projection. - `dropout_p`: Dropout probability on attention weights. + - `inference`: Whether to actually apply dropout at all. If `True` then dropout + is not applied. If `False` then dropout is applied. This may be toggled + with [`equinox.tree_inference`][] or overridden during + [`equinox.nn.MultiheadAttention.__call__`][]. - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.) """ @@ -155,7 +160,7 @@ def __init__( self.output_proj = Linear( num_heads * vo_size, output_size, use_bias=use_output_bias, key=okey ) - self.dropout = Dropout(dropout_p) + self.dropout = Dropout(dropout_p, inference=inference) self.num_heads = num_heads self.query_size = query_size @@ -179,6 +184,7 @@ def __call__( ] = None, *, key: Optional["jax.random.PRNGKey"] = None, + inference: Optional[bool] = None, deterministic: Optional[bool] = None, ) -> Array["query_seq_length", "output_size"]: # noqa: F821 """**Arguments:** @@ -193,8 +199,9 @@ def __call__( JAX array of shape `(num_heads, query_seq_length, kv_seq_length)`. - `key`: A `jax.random.PRNGKey` used for dropout. Unused if `dropout = 0`. (Keyword only argument.) - - `deterministic`: As [`equinox.nn.Dropout.__call__`][]. (Keyword only + - `inference`: As [`equinox.nn.Dropout.__call__`][]. (Keyword only argument.) + - `deterministic`: (Deprecated in favour of `inference`.) **Returns:** @@ -224,7 +231,9 @@ def __call__( logits = jnp.where(mask, logits, -jnp.inf) weights = jax.nn.softmax(logits, axis=-1) - weights = self.dropout(weights, key=key, deterministic=deterministic) + weights = self.dropout( + weights, key=key, inference=inference, deterministic=deterministic + ) attn = jnp.einsum("hsS,Shd->shd", weights, value_heads) attn = attn.reshape(query_seq_length, -1) diff --git a/equinox/nn/dropout.py b/equinox/nn/dropout.py index b6a120cf..1405db6b 100644 --- a/equinox/nn/dropout.py +++ b/equinox/nn/dropout.py @@ -27,29 +27,10 @@ def __init__( - `p`: The fraction of entries to set to zero. (On average.) - `inference`: Whether to actually apply dropout at all. If `True` then dropout - is *not* applied. If `False` then dropout is applied. + is *not* applied. If `False` then dropout is applied. This may be toggled + with [`equinox.tree_inference`][] or overridden during + [`equinox.nn.Dropout.__call__`][]. - `deterministic`: Deprecated alternative to `inference`. - - !!! 
info
-
-        The `inference` flag is provided as it is common to only apply dropout
-        during training, but not to apply it during inference. If you want to change
-        this flag between training and inference, then you can either:
-
-        - Override it with the `__call__`-time `inference` flag, see below.
-        - Modify the `inference` flag directly -- possible to do because `Dropout`
-          is just a PyTree. For example this sets all `inference` flags to
-          `True`:
-          ```python
-          model = ... # some model featuring Dropout/BatchNorm/etc. layers.
-
-          def find_inference(m):
-              has_inference = lambda x: hasattr(x, "inference")
-              leaves = jax.tree_leaves(m, is_leaf=has_inference)
-              return tuple(k.inference for k in leaves if has_inference(k))
-
-          model = eqx.tree_at(find_inference, model, replace_fn=lambda _: True)
-          ```
         """

         if deterministic is not None:
diff --git a/equinox/nn/pool.py b/equinox/nn/pool.py
new file mode 100644
index 00000000..4d6b91f9
--- /dev/null
+++ b/equinox/nn/pool.py
@@ -0,0 +1,258 @@
+from typing import Callable, Optional, Sequence, Tuple, Union
+
+import jax.lax as lax
+import jax.numpy as jnp
+import jax.random
+import numpy as np
+
+from ..custom_types import Array
+from ..module import Module, static_field
+
+
+class Pool(Module):
+    """General N-dimensional downsampling over a sliding window."""
+
+    init: Union[int, float, Array]
+    operation: Callable[[Array, Array], Array]
+    num_spatial_dims: int = static_field()
+    kernel_size: Union[int, Sequence[int]] = static_field()
+    stride: Union[int, Sequence[int]] = static_field()
+    padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = static_field()
+
+    def __init__(
+        self,
+        init: Union[int, float, Array],
+        operation: Callable[[Array, Array], Array],
+        num_spatial_dims: int,
+        kernel_size: Union[int, Sequence[int]],
+        stride: Union[int, Sequence[int]] = 1,
+        padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0,
+        **kwargs,
+    ):
+        """**Arguments:**
+        - `init`: The initial value for the reduction.
+        - `operation`: The operation applied to the inputs of each window.
+        - `num_spatial_dims`: The number of spatial dimensions.
+        - `kernel_size`: The size of the convolutional kernel.
+        - `stride`: The stride of the convolution.
+        - `padding`: The amount of padding to apply before and after each
+            spatial dimension.
+
+        !!! info
+
+            In order for `Pool` to be differentiable, `operation(init, x) == x` needs to
+            be true for all finite `x`. For further details see
+            https://www.tensorflow.org/xla/operation_semantics#reducewindow and
+            https://github.com/google/jax/issues/7718.
+
+        """
+        super().__init__(**kwargs)
+
+        self.operation = operation
+        self.init = init
+        self.num_spatial_dims = num_spatial_dims
+
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size,) * num_spatial_dims
+        elif isinstance(kernel_size, Sequence):
+            self.kernel_size = kernel_size
+        else:
+            raise ValueError(
+                "`kernel_size` must either be an int or tuple of length "
+                f"{num_spatial_dims} containing ints."
+            )
+
+        if isinstance(stride, int):
+            self.stride = (stride,) * num_spatial_dims
+        elif isinstance(stride, Sequence):
+            self.stride = stride
+        else:
+            raise ValueError(
+                "`stride` must either be an int or tuple of length "
+                f"{num_spatial_dims} containing ints."
+ ) + + if isinstance(padding, int): + self.padding = tuple((padding, padding) for _ in range(num_spatial_dims)) + elif isinstance(padding, Sequence) and len(padding) == num_spatial_dims: + if all(isinstance(element, Sequence) for element in padding): + self.padding = padding + else: + self.padding = tuple((p, p) for p in padding) + else: + raise ValueError( + "`padding` must either be an int or tuple of length " + f"{num_spatial_dims} containing ints or tuples of length 2." + ) + + def __call__( + self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None + ) -> Array: + """**Arguments:** + - `x`: The input. Should be a JAX array of shape `(channels, dim_1, ..., dim_N)`, where + `N = num_spatial_dims`. + - `key`: Ignored; provided for compatibility with the rest of the Equinox API. + (Keyword only argument.) + **Returns:** + A JAX array of shape `(channels, new_dim_1, ..., new_dim_N)`. + """ + assert len(x.shape) == self.num_spatial_dims + 1, ( + f"Input should have {self.num_spatial_dims} spatial dimensions, " + f"but input has shape {x.shape}" + ) + + x = jnp.moveaxis(x, 0, -1) + x = jnp.expand_dims(x, axis=0) + x = lax.reduce_window( + x, + self.init, + self.operation, + (1,) + self.kernel_size + (1,), + (1,) + self.stride + (1,), + ((0, 0),) + self.padding + ((0, 0),), + ) + + x = jnp.squeeze(x, axis=0) + x = jnp.moveaxis(x, -1, 0) + return x + + +class AvgPool1D(Pool): + """One-dimensional downsample using an average over a sliding window.""" + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + **kwargs, + ): + super().__init__( + init=0, + operation=lax.add, + num_spatial_dims=1, + kernel_size=kernel_size, + stride=stride, + padding=padding, + **kwargs, + ) + + def __call__( + self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None + ) -> Array: + return super().__call__(x) / np.prod(self.kernel_size) + + +class MaxPool1D(Pool): + """One-dimensional downsample using the maximum over a sliding window.""" + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + **kwargs, + ): + super().__init__( + init=-jnp.inf, + operation=lax.max, + num_spatial_dims=1, + kernel_size=kernel_size, + stride=stride, + padding=padding, + **kwargs, + ) + + +class AvgPool2D(Pool): + """Two-dimensional downsample using an average over a sliding window.""" + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + **kwargs, + ): + super().__init__( + init=0, + operation=lax.add, + num_spatial_dims=2, + kernel_size=kernel_size, + stride=stride, + padding=padding, + **kwargs, + ) + + def __call__( + self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None + ) -> Array: + return super().__call__(x) / np.prod(self.kernel_size) + + +class MaxPool2D(Pool): + """Two-dimensional downsample using the maximum over a sliding window.""" + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + **kwargs, + ): + super().__init__( + init=-jnp.inf, + operation=lax.max, + num_spatial_dims=2, + kernel_size=kernel_size, + stride=stride, + padding=padding, + **kwargs, + ) + + +class AvgPool3D(Pool): + """Three-dimensional downsample using an average over a sliding window.""" + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + **kwargs, + ): + super().__init__( + init=0, + operation=lax.add, + num_spatial_dims=3, + kernel_size=kernel_size, + stride=stride, + padding=padding, + **kwargs, + ) + + def __call__( + self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None + ) -> Array: + return super().__call__(x) / np.prod(self.kernel_size) 
+ + +class MaxPool3D(Pool): + """Three-dimensional downsample using the maximum over a sliding window.""" + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + **kwargs, + ): + super().__init__( + init=-jnp.inf, + operation=lax.max, + num_spatial_dims=3, + kernel_size=kernel_size, + stride=stride, + padding=padding, + **kwargs, + ) diff --git a/equinox/pretty_print.py b/equinox/pretty_print.py index 1650af94..3a54a364 100644 --- a/equinox/pretty_print.py +++ b/equinox/pretty_print.py @@ -8,11 +8,11 @@ import jax.numpy as jnp import numpy as np -from .custom_types import Array +from .custom_types import Array, PyTree Dataclass = Any -PrettyPrintAble = Any +PrettyPrintAble = PyTree _comma_sep = pp.concat([pp.text(","), pp.brk()]) @@ -182,14 +182,14 @@ def tree_pformat( ) -> str: """Pretty-formats a PyTree as a string, whilst abbreviating JAX arrays. + (This is the function used in `__repr__` of [`equinox.Module`][].) + All JAX arrays in the PyTree are condensed down to a short string representation of their dtype and shape. - (This is the function used in `__repr__` of [`equinox.Module`][].) - !!! example - A 32-bit floating-point JAX array of shape `(3, 4)` is printed as `f32[3,4]`. + A 32-bit floating-point JAX array of shape `(3, 4)` is printed as `f32[3,4]`. **Arguments:** diff --git a/equinox/serialisation.py b/equinox/serialisation.py new file mode 100644 index 00000000..c1b0f31a --- /dev/null +++ b/equinox/serialisation.py @@ -0,0 +1,180 @@ +import pathlib +from typing import Any, Callable, Union + +import jax +import jax.numpy as jnp +import numpy as np + +from . import experimental +from .custom_types import PyTree + + +def _default_serialise_filter_spec(f, x): + if isinstance(x, jnp.ndarray): + jnp.save(f, x) + elif isinstance(x, np.ndarray): + np.save(f, x) + elif isinstance(x, (bool, float, complex, int)): + np.save(f, x) + elif isinstance(x, experimental.StateIndex): + value, _, _ = x.unsafe_get() + jnp.save(f, value) + else: + pass + + +def _default_deserialise_filter_spec(f, x): + if isinstance(x, jnp.ndarray): + return jnp.load(f) + elif isinstance(x, np.ndarray): + return np.load(f) + elif isinstance(x, (bool, float, complex, int)): + return np.load(f).item() + elif isinstance(x, experimental.StateIndex): + value = jnp.load(f) + experimental.set_state(x, value) + return x + else: + return x + + +def _with_suffix(path): + path = pathlib.Path(path) + if path.suffix == "": + return path.with_suffix(".eqx") + else: + return path + + +def _assert_same(new, old): + if type(new) is not type(old): + raise RuntimeError(...) + if isinstance(new, (np.ndarray, jnp.ndarray)) and ( + new.shape != old.shape or new.dtype != old.dtype + ): + raise RuntimeError(...) + + +def _is_index(x): + return isinstance(x, experimental.StateIndex) + + +def tree_serialise_leaves( + path: Union[str, pathlib.Path], + pytree: PyTree, + filter_spec=_default_serialise_filter_spec, + is_leaf: Callable[[Any], bool] = _is_index, +) -> None: + """Save the leaves of a PyTree to file. + + **Arguments:** + + - `path`: The file location to save values to. + - `pytree`: The PyTree whose leaves will be saved. + - `filter_spec`: Specifies how to save each kind of leaf. By default all JAX + arrays, NumPy arrays, Python bool/int/float/complexes are saved, + [`equinox.experimental.StateIndex`][] instances have their value looked up + and saved, and all other leaf types are ignored. + - `is_leaf`: Called on every node of `pytree`; if `True` then this node will be + treated as a leaf. 
+
+    **Returns:**
+
+    Nothing.
+
+    !!! example
+
+        This can be used to save a model to file.
+
+        ```python
+        import equinox as eqx
+        import jax.random as jr
+
+        model = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
+        eqx.tree_serialise_leaves("some_filename.eqx", model)
+        ```
+
+    !!! info
+
+        `filter_spec` should typically be a function `(File, Any) -> None`, which takes
+        a file handle and a leaf to save, and either saves the leaf to the file or
+        does nothing.
+
+        It can also be a PyTree of such functions, in which case the PyTree structure
+        should be a prefix of `pytree`, and each function will be mapped over the
+        corresponding sub-PyTree of `pytree`.
+    """
+
+    with open(_with_suffix(path), "wb") as f:
+
+        def _serialise(spec, x):
+            def __serialise(y):
+                spec(f, y)
+
+            jax.tree_map(__serialise, x, is_leaf=is_leaf)
+
+        jax.tree_map(_serialise, filter_spec, pytree)
+
+
+def tree_deserialise_leaves(
+    path: Union[str, pathlib.Path],
+    like: PyTree,
+    filter_spec=_default_deserialise_filter_spec,
+    is_leaf: Callable[[Any], bool] = _is_index,
+) -> PyTree:
+    """Load the leaves of a PyTree from a file.
+
+    **Arguments:**
+
+    - `path`: The file location to load values from.
+    - `like`: A PyTree of the same structure, and with leaves of the same type, as the
+        PyTree being loaded. Those leaves which are loaded will replace the
+        corresponding leaves of `like`.
+    - `filter_spec`: Specifies how to load each kind of leaf. By default all JAX
+        arrays, NumPy arrays, Python bool/int/float/complexes are loaded,
+        [`equinox.experimental.StateIndex`][] instances have their value looked up
+        and stored, and all other leaf types are not loaded (and will retain their
+        value from `like`).
+    - `is_leaf`: Called on every node of `like`; if `True` then this node will be
+        treated as a leaf.
+
+    **Returns:**
+
+    The loaded PyTree, formed by iterating over `like` and replacing some of its leaves
+    with the leaves saved in `path`.
+
+    !!! example
+
+        This can be used to load a model from file.
+
+        ```python
+        import equinox as eqx
+        import jax.random as jr
+
+        model = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
+        eqx.tree_serialise_leaves("some_filename.eqx", model)
+        model2 = eqx.tree_deserialise_leaves("some_filename.eqx", model)
+        ```
+
+    !!! info
+
+        `filter_spec` should typically be a function `(File, Any) -> Any`, which takes
+        a file handle and a leaf from `like`, and either returns the corresponding
+        loaded leaf, or returns the leaf from `like` unchanged.
+
+        It can also be a PyTree of such functions, in which case the PyTree structure
+        should be a prefix of `like`, and each function will be mapped over the
+        corresponding sub-PyTree of `like`.
+ """ + + with open(_with_suffix(path), "rb") as f: + + def _deserialise(spec, x): + def __deserialise(y): + return spec(f, y) + + return jax.tree_map(__deserialise, x, is_leaf=is_leaf) + + out = jax.tree_map(_deserialise, filter_spec, like) + jax.tree_map(_assert_same, out, like, is_leaf=is_leaf) + return out diff --git a/equinox/tree.py b/equinox/tree.py index aca9fe2d..513acace 100644 --- a/equinox/tree.py +++ b/equinox/tree.py @@ -1,41 +1,73 @@ -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union import jax import jax.numpy as jnp import numpy as np -from .custom_types import PyTree +from .custom_types import PyTree, sentinel +from .doc_utils import doc_repr -_sentinel = object() +_Node = doc_repr(Any, "Node") -_Leaf = Any + +class _LeafWrapper: + def __init__(self, value: Any): + self.value = value + + +def _remove_leaf_wrapper(x: _LeafWrapper) -> Any: + assert type(x) is _LeafWrapper + return x.value + + +class _CountedIdDict: + def __init__(self, keys, values): + assert len(keys) == len(values) + self._dict = {id(k): v for k, v in zip(keys, values)} + self._count = {id(k): 0 for k in keys} + + def __contains__(self, item): + return id(item) in self._dict + + def __getitem__(self, item): + self._count[id(item)] += 1 + return self._dict[id(item)] + + def get(self, item, default): + try: + return self[item] + except KeyError: + return default + + def count(self, item): + return self._count[id(item)] def tree_at( - where: Callable[[PyTree], Union[_Leaf, Sequence[_Leaf]]], + where: Callable[[PyTree], Union[_Node, Sequence[_Node]]], pytree: PyTree, - replace: Union[_Leaf, Sequence[_Leaf]] = _sentinel, - replace_fn: Callable[[_Leaf], _Leaf] = _sentinel, - is_leaf: Callable[[_Leaf], bool] = None, -) -> PyTree: + replace: Union[Any, Sequence[Any]] = sentinel, + replace_fn: Callable[[_Node], Any] = sentinel, + is_leaf: Optional[Callable[[Any], bool]] = None, +): """Updates a PyTree out-of-place; a bit like using `.at[].set()` on a JAX array. **Arguments:** - - `where`: A callable `PyTree -> Leaf` or `PyTree -> Sequence[Leaf]`. It should - consume a PyTree with the same structure as `pytree`, and return the leaf or - leaves that should be replaced. For example + - `where`: A callable `PyTree -> Node` or `PyTree -> Sequence[Node]`. It should + consume a PyTree with the same structure as `pytree`, and return the node or + nodes that should be replaced. For example `where = lambda mlp: mlp.layers[-1].linear.weight`. - `pytree`: The PyTree to modify. - `replace`: Either a single element, or a sequence of the same length as returned by `where`. This specifies the replacements to make at the locations specified by `where`. Mutually exclusive with `replace_fn`. - - `replace_fn`: A function `Leaf -> Any`. It will be called on every leaf specified + - `replace_fn`: A function `Node -> Any`. It will be called on every node specified by `where`. The return value from `replace_fn` will be used in its place. Mutually exclusive with `replace`. - - `is_leaf`: As `jax.tree_flatten`; used to determine what should be treated as a - leaf. + - `is_leaf`: As `jax.tree_flatten`. For example pass `is_leaf=lambda x: x is None` + to be able to replace `None` values using `tree_at`. **Returns:** @@ -44,69 +76,137 @@ def tree_at( !!! example This can be used to help specify the weights of a model to train or not to - train: + train. For example the following will train only the weight of the final linear + layer of an MLP: ```python - model = ... 
+ def loss(model, ...): + ... + + model = eqx.nn.MLP(...) trainable = jax.tree_map(lambda _: False, model) trainable = equinox.tree_at(lambda mlp: mlp.layers[-1].linear.weight, model, replace=True) - equinox.filter_grad(..., filter_spec=trainable) - ``` - - !!! example - - Sub-PyTrees can be replaced by flattening them to leaves first: - - ```python - equinox.tree_at(lambda t: jax.tree_leaves(t.subtree), pytree, - jax.tree_leaves(new_subtree)) + grad_loss = equinox.filter_grad(loss, arg=trainable) + grads = grad_loss(model) ``` """ - if (replace is _sentinel and replace_fn is _sentinel) or ( - replace is not _sentinel and replace_fn is not _sentinel + # We need to specify a particular node in a PyTree. + # This is surprisingly difficult to do! As far as I can see, pretty much the only + # way of doing this is to specify e.g. `x.foo[0].bar` via `is`, and then pulling + # a few tricks to try and ensure that the same object doesn't appear multiple + # times in the same PyTree. + # + # So this first `tree_map` serves a dual purpose. + # 1) Makes a copy of the composite nodes in the PyTree, to avoid aliasing via + # e.g. `pytree=[(1,)] * 5`. This has the tuple `(1,)` appear multiple times. + # 2) It makes each leaf be a unique Python object, as it's wrapped in + # `_LeafWrapper`. This is needed because Python caches a few builtin objects: + # `assert 0 + 1 is 1`. I think only a few leaf types are subject to this. + # So point 1) should ensure that all composite nodes are unique Python objects, + # and point 2) should ensure that all leaves are unique Python objects. + # Between them, all nodes of `pytree` are handled. + # + # I think pretty much the only way this can fail is when using a custom node with + # singleton-like flatten+unflatten behaviour, which is pretty edge case. And we've + # added a check for it at the bottom of this function, just to be sure. + # + # Whilst we're here: we also double-check that `where` is well-formed and doesn't + # use leaf information. (As else `node_or_nodes` will be wrong.) + node_or_nodes_nowrapper = where(pytree) + pytree = jax.tree_map(_LeafWrapper, pytree, is_leaf=is_leaf) + node_or_nodes = where(pytree) + leaves1, structure1 = jax.tree_flatten(node_or_nodes_nowrapper, is_leaf=is_leaf) + leaves2, structure2 = jax.tree_flatten(node_or_nodes) + leaves2 = [_remove_leaf_wrapper(x) for x in leaves2] + if ( + structure1 != structure2 + or len(leaves1) != len(leaves2) + or any(l1 is not l2 for l1, l2 in zip(leaves1, leaves2)) ): raise ValueError( - "Precisely one of `replace` and `replace_fn` must be specified." + "`where` must use just the PyTree structure of `pytree`. `where` must not " + "depend on the leaves in `pytree`." ) - elif replace is _sentinel: - replace_passed = False - replacer = lambda j, i: replace_fn(flat[i]) + del node_or_nodes_nowrapper, leaves1, structure1, leaves2, structure2 + + # Normalise whether we were passed a single node or a sequence of nodes. + in_pytree = False + + def _in_pytree(x): + nonlocal in_pytree + if x is node_or_nodes: # noqa: F821 + in_pytree = True + + jax.tree_map(_in_pytree, pytree, is_leaf=lambda x: x is node_or_nodes) # noqa: F821 + if in_pytree: + nodes = (node_or_nodes,) + if replace is not sentinel: + replace = (replace,) + else: + nodes = node_or_nodes + del in_pytree, node_or_nodes + + # Normalise replace vs replace_fn + if replace is sentinel: + if replace_fn is sentinel: + raise ValueError( + "Precisely one of `replace` and `replace_fn` must be specified." 
+ ) + else: + + def _replace_fn(x): + x = jax.tree_map(_remove_leaf_wrapper, x) + return replace_fn(x) + + replace_fns = [_replace_fn] * len(nodes) else: - replace_passed = True - replacer = lambda j, i: replace[j] - - # TODO: is there a neater way of accomplishing this? - flat, treedef = jax.tree_flatten(pytree, is_leaf=is_leaf) - flat_indices = list(range(len(flat))) - index_pytree = jax.tree_unflatten(treedef, flat_indices) - index = where(index_pytree) - # where can return either a single entry, or a sequence - if isinstance(index, int): - index = (index,) - replace = (replace,) - elif isinstance(index, Sequence): - for i in index: - if not isinstance(i, int): + if replace_fn is sentinel: + if len(nodes) != len(replace): raise ValueError( - r"""`where` must return a sequence of only leaves; not some subtree. - - If you want to replace all of a subtree, you can do so by replacing - >>> eqx.tree_at(lambda t: t.subtree, tree, new_subtree) # buggy - with - >>> eqx.tree_at(lambda t: jax.tree_leaves(t.subtree), tree, - ... jax.tree_leaves(new_subtree)) # fixed - """ + "`where` must return a sequence of leaves of the same length as " + "`replace`." ) + replace_fns = [lambda _, r=r: r for r in replace] + else: + raise ValueError( + "Precisely one of `replace` and `replace_fn` must be specified." + ) + node_replace_fns = _CountedIdDict(nodes, replace_fns) - if replace_passed and len(index) != len(replace): - raise ValueError( - "`where` must return a sequence of leaves of the same length as `replace`." - ) - for j, i in enumerate(index): - flat[i] = replacer(j, i) + # Actually do the replacement + def _make_replacement(x: _Node) -> Any: + return node_replace_fns.get(x, _remove_leaf_wrapper)(x) + + out = jax.tree_map( + _make_replacement, pytree, is_leaf=lambda x: x in node_replace_fns + ) + + # Check that `where` is well-formed. + for node in nodes: + count = node_replace_fns.count(node) + if count == 0: + raise ValueError( + "`where` does not specify an element or elements of `pytree`." + ) + elif count == 1: + pass + else: + raise ValueError( + "`where` does not uniquely identify a single element of `pytree`. This " + "usually occurs when trying to replace a `None` value:\n" + "\n" + " >>> eqx.tree_at(lambda t: t[0], (None, None, 1), True)\n" + "\n" + "\n" + "for which the fix is to specify that `None`s should be treated as " + "leaves:\n" + "\n" + " >>> eqx.tree_at(lambda t: t[0], (None, None, 1), True,\n" + " ... 
is_leaf=lambda x: x is None)" + ) - return jax.tree_unflatten(treedef, flat) + return out def tree_equal(*pytrees: PyTree) -> bool: @@ -125,6 +225,7 @@ def tree_equal(*pytrees: PyTree) -> bool: """ flat, treedef = jax.tree_flatten(pytrees[0]) array_types = (jnp.ndarray, np.ndarray) + out = True for pytree in pytrees[1:]: flat_, treedef_ = jax.tree_flatten(pytree) if treedef_ != treedef: @@ -136,9 +237,12 @@ def tree_equal(*pytrees: PyTree) -> bool: (type(elem) != type(elem_)) or (elem.shape != elem_.shape) or (elem.dtype != elem_.dtype) - or (elem != elem_).any() ): return False + allsame = (elem == elem_).all() + if allsame is False: + return False + out = out & allsame else: return False else: @@ -147,4 +251,55 @@ def tree_equal(*pytrees: PyTree) -> bool: else: if elem != elem_: return False - return True + return out + + +def _has_inference(leaf): + return hasattr(leaf, "inference") + + +def _inferences(pytree): + return tuple( + x.inference + for x in jax.tree_leaves(pytree, is_leaf=_has_inference) + if _has_inference(x) + ) + + +def tree_inference(pytree: PyTree, value: bool) -> PyTree: + """Convenience function for setting all `inference` attributes on a PyTree. + + Equivalent to: + ```python + has_inference = lambda leaf: hasattr(leaf, "inference") + + def where(pytree): + return tuple(x.inference + for x in jax.tree_leaves(pytree, is_leaf=has_inference) + if has_inference(x)) + + equinox.tree_at(where, pytree, replace_fn=lambda _: value) + ``` + + `inference` flags are used to toggle the behaviour of a number of the pre-built + neural network layers, such as [`equinox.nn.Dropout`][] or + [`equinox.experimental.BatchNorm`][]. + + **Arguments:** + + - `pytree`: the PyTree to modify. + - `value`: the value to set all `inference` attributes to. + + **Returns:** + + A copy of `pytree` with all `inference` flags set to `value`. + """ + + # For the sake of equinox.experimental.StateIndex. This won't defend against anyone + # setting inference flags manually using tree_at etc., but it should help overall. + if isinstance(jnp.array(0) + 1, jax.core.Tracer): + raise RuntimeError( + "inference flags should not be set whilst jit'ing, vmap'ing etc." 
+ ) + + return tree_at(_inferences, pytree, replace_fn=lambda _: value) diff --git a/equinox/vmap_pmap.py b/equinox/vmap_pmap.py new file mode 100644 index 00000000..dac15eb7 --- /dev/null +++ b/equinox/vmap_pmap.py @@ -0,0 +1,599 @@ +import functools as ft +import inspect +from typing import Any, Callable, Dict, Union + +import jax +import jax.interpreters.batching as batching +import jax.numpy as jnp + +from .compile_utils import ( + compile_cache, + hashable_combine, + hashable_partition, + Static, + strip_wrapped_partial, +) +from .custom_types import BoolAxisSpec, PyTree, ResolvedBoolAxisSpec, sentinel +from .doc_utils import doc_strip_annotations +from .filters import combine, filter, is_array, partition +from .module import Module, module_update_wrapper + + +ResolvedMapAxisSpec = Union[None, int] +MapAxisSpec = Union[ResolvedMapAxisSpec, Callable[[Any], ResolvedMapAxisSpec]] +# +ResolvedAxisSpec = Union[ResolvedBoolAxisSpec, ResolvedMapAxisSpec] +AxisSpec = Union[ResolvedAxisSpec, Callable[[Any], ResolvedAxisSpec]] + + +def _is_none(x: Any) -> bool: + return x is None + + +def _resolve_axis(axis_spec: AxisSpec, elem: Any) -> PyTree[ResolvedAxisSpec]: + if axis_spec is None or isinstance(axis_spec, (bool, int)): + return axis_spec + if callable(axis_spec): + return jax.tree_map(axis_spec, elem) + else: + raise ValueError( + "`in_axes` and `out_axes` must consist of None, bools, ints, and callables only." + ) + + +def _resolve_axes( + pytree: PyTree[Any], axes_spec: PyTree[AxisSpec] +) -> PyTree[ResolvedAxisSpec]: + return jax.tree_map(_resolve_axis, axes_spec, pytree, is_leaf=_is_none) + + +def _jit_axis(axis: ResolvedAxisSpec) -> BoolAxisSpec: # not necessarily resolved + if isinstance(axis, bool): + return axis + elif isinstance(axis, int): + return True + elif axis is None: + return is_array + else: + assert False + + +def _map_axis(axis: ResolvedAxisSpec) -> ResolvedMapAxisSpec: + if isinstance(axis, bool): + return None + elif isinstance(axis, int): + return axis + elif axis is None: + return None + else: + assert False + + +def _jit_axes(axes: PyTree[ResolvedAxisSpec]) -> PyTree[BoolAxisSpec]: + return jax.tree_map(_jit_axis, axes, is_leaf=_is_none) + + +def _map_axes(axes: PyTree[ResolvedAxisSpec]) -> PyTree[ResolvedMapAxisSpec]: + return jax.tree_map(_map_axis, axes, is_leaf=_is_none) + + +class _VmapFilter: + def __init__(self, axis: AxisSpec): + self.axis = axis + + +_have_monkey_patched = False + + +def _monkey_patch(): + global _have_monkey_patched + if not _have_monkey_patched: + _have_monkey_patched = True + + _old_from_elt = batching.from_elt + + def from_elt(trace, axis_size, x, spec): + if isinstance(spec, _VmapFilter): + spec = _resolve_axis(spec.axis, x) + spec = _map_axis(spec) + return _old_from_elt(trace, axis_size, x, spec) + + batching.from_elt = from_elt + batching.spec_types.add(_VmapFilter) + + +def _zero_if_array_else_none(x: Any) -> ResolvedMapAxisSpec: + return 0 if is_array(x) else None + + +class _VmapWrapper(Module): + _signature: inspect.Signature + _fun: Callable + _default: AxisSpec + _fn: PyTree[AxisSpec] + _args: PyTree[AxisSpec] + _kwargs: PyTree[AxisSpec] + _out: PyTree[AxisSpec] + _callable_out_axes: bool + _vmapkwargs: Dict[str, Any] + + def __call__(__self, *args, **kwargs): + def _fun_wrapper(_fun, _args, _kwargs): + result = _fun(*_args, **_kwargs) + out_axes = _resolve_axes(result, __self._out) + out_axes = _map_axes(out_axes) + out_axes_is_none = jax.tree_map(_is_none, out_axes, is_leaf=_is_none) + nonvmapd, vmapd = partition(result, 
out_axes_is_none) + return vmapd, Static(nonvmapd) + + bound = __self._signature.bind(*args, **kwargs) + del args, kwargs + bound.apply_defaults() + _args = __self._args + (__self._default,) * ( + len(bound.args) - len(__self._args) + ) + _kwargs = { + key: __self._kwargs.get(key, __self._default) for key in bound.kwargs + } + in_axes = _resolve_axes( + (__self._fun, bound.args, bound.kwargs), (__self._fn, _args, _kwargs) + ) + in_axes = _map_axes(in_axes) + if __self._callable_out_axes: # `out` of type AxisSpec + out_axes = (jax.tree_map(_VmapFilter, __self._out), None) + else: # `out` of type ResolvedAxisSpec + out_axes = _map_axes(__self._out) + vmapd, nonvmapd = jax.vmap( + _fun_wrapper, in_axes=in_axes, out_axes=out_axes, **__self._vmapkwargs + )(__self._fun, bound.args, bound.kwargs) + return combine(vmapd, nonvmapd.value) + + def __get__(self, instance, owner): + if instance is None: + return self + return jax.tree_util.Partial(self, instance) + + +# Note the use of AxisSpec rather than MapAxisSpec. +# This is to support seamlessly switching out filter_pmap for filter_vmap. +@doc_strip_annotations +def filter_vmap( + fun: Callable = sentinel, + *, + default: AxisSpec = _zero_if_array_else_none, + fn: PyTree[AxisSpec] = None, + args: PyTree[AxisSpec] = (), + kwargs: PyTree[AxisSpec] = None, + # `out` default would ideally be _zero_if_array_else_none but that hits + # experimental behaviour, so it's not a good default. + # As a bonus, this also keeps the default the same as filter_pmap. + out: PyTree[AxisSpec] = 0, + **vmapkwargs +) -> Callable: + """Wraps together [`equinox.partition`][] and `jax.vmap`. + + !!! info + + By default, all JAX arrays are vectorised down their leading axis (i.e. axis + index 0), and all other types are not vectorised. + + **Arguments:** + + In each of the following cases, then `int` indicates an array axis to vectorise + over, `None` indicates that an argument should be broadcast (not vectorised + over), and functions `Leaf -> Union[None, int]` are mapped and evaluated on every + leaf of their subtree. + + (This is the same semantics as `jax.vmap(in_axes=..., out_axes=...)`.) + + `None` may be used for non-JAX-array arguments. It is an error to try and specify + an integer axis for a non-JAX-array. + + - `fun` is a pure function to vectorise. + - `default` should be a `Union[None, int]` or a function + `Leaf -> Union[None, int]`, and is applied by default to every argument and + keyword argument to `fun`. + - `args` is an optional per-argument override for `default`, and should be a tuple + of PyTrees with leaves that are either `Union[None, int]`s or functions + `Leaf -> Union[None, int]`. The PyTree structures should be prefixes of the + corresponding input to `fun`. + - `kwargs` is an optional per-keyword-argument override for `default` and should be + a dictionary, whose keys are the names of arguments to `fun`, and whose values + are PyTrees with leaves that either `Union[None, int]`s or functions + `Leaf -> Union[None, int]`. The PyTree structures should be prefixes of the + corresponding input to `fun`. + - `out` should be a PyTree with leaves that either `Union[None, int]`s or functions + `Leaf -> Union[None, int]`. The PyTree structure should be a prefix of the + output of `fun`. + - `fn` should be a PyTree with leaves that either `Union[None, int]`s or functions + `Leaf -> Union[None, int]`. The PyTree structure should be a prefix of `fun` + itself. (Note that `fun` may be any callable, e.g. 
a bound method, or a class + implementing `__call__`, and doesn't have to be a normal Python function.) + - `**vmapkwargs` are any other keyword arguments to `jax.vmap`. + + When `args`, `kwargs`, `out`, `fn` are prefixes of the corresponding input, their + value will be mapped over the input PyTree. + + **Returns:** + + The vectorised version of `fun`. + + !!! info + + In fact, besides `None`, `int` and `Leaf -> Union[None, int]`, then boolean + types are also supported, and treated identically to `None`. This is to support + seamlessly switching between [`equinox.filter_pmap`][] and + [`equinox.filter_vmap`][] if desired. + + !!! warning + + Using functions `Leaf -> Union[None, int]` in `out` is considered experimental, + and may change. + + !!! example + + ```python + import equinox as eqx + import jax.numpy as jnp + + @eqx.filter_vmap + def f(x, y): + return x + y + + @eqx.filter_vmap(kwargs=dict(x=1)) + def g(x, y): + return x + y + + @eqx.filter_vmap(args=(None,)) + def h(x, y): + return x + y + + f(jnp.array([1, 2]), jnp.array([3, 4])) # both args vectorised down axis 0 + f(jnp.array([1, 2]), 3) # first arg vectorised down axis 0 + # second arg broadcasted + + g(jnp.array([[1, 2]]), jnp.array([3, 4])) # first arg vectorised down axis 1 + # second arg vectorised down axis 0 + + h(jnp.array(1), jnp.array([2, 3])) # first arg broadcasted + # second arg vectorised down axis 0 + ``` + + !!! example + + `filter_vmap` can be used to easily create ensembles of models. For example, here's an + ensemble of eight MLPs: + + ```python + import equinox as eqx + import jax.random as jr + + key = jr.PRNGKey(0) + keys = jr.split(key, 8) + + # Create an ensemble of models + + @eqx.filter_vmap(out=lambda x: 0 if eqx.is_array(x) else None) + def make_ensemble(key): + return eqx.nn.MLP(2, 2, 2, 2, key=key) + + mlp_ensemble = make_ensemble(keys) + + # Evaluate each member of the ensemble on the same data + + @eqx.filter_vmap(kwargs=dict(x=None)) + def evaluate_ensemble(model, x): + return model(x) + + evaluate_ensemble(mlp_ensemble, jr.normal(key, (2,))) + + # Evaluate each member of the ensemble on different data + + @eqx.filter_vmap + def evaluate_per_ensemble(model, x): + return model(x) + + evaluate_per_ensemble(mlp_ensemble, jr.normal(key, (8, 2))) + ``` + + Here, `make_ensemble` works because [`equinox.nn.MLP`][] is a PyTree, and so it + is a valid output from a `filter_vmap`. This PyTree includes some JAX arrays + (the weights and biases) and some non-JAX-arrays (e.g. activation functions). + `filter_vmap` will vectorise the JAX arrays (with separate weights for each + member of the ensemble) whilst leaving the non-JAX-arrays alone. + + Note that as the weights in `mlp_ensemble` now have a leading batch dimension + -- that the weights of `eqx.nn.MLP` instances do not typically have -- then it + cannot be called directly. It must instead be passed back into a vectorised + region to be called. 
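+
+        As a rough sketch (assuming the `mlp_ensemble` and `key` from above), one way
+        to recover a single member is to index just the stacked arrays and recombine
+        them with the shared non-array leaves:
+
+        ```python
+        import jax
+
+        params, static = eqx.partition(mlp_ensemble, eqx.is_array)
+        member0 = eqx.combine(jax.tree_map(lambda x: x[0], params), static)
+        member0(jr.normal(key, (2,)))  # an ordinary MLP, callable directly
+        ```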
+ """ + + if fun is sentinel: + return ft.partial( + filter_vmap, + default=default, + fn=fn, + args=args, + kwargs=kwargs, + out=out, + **vmapkwargs + ) + + if kwargs is None: + kwargs = {} + + signature = inspect.signature(fun) + + signature_default = signature.replace( + parameters=[ + p + if p.kind + in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + else p.replace(default=default) + for p in signature.parameters.values() + ] + ) + bound = signature_default.bind_partial(*args, **kwargs) + del args, kwargs + bound.apply_defaults() + + if any(callable(o) for o in jax.tree_leaves(out)): + # Experimental behaviour + _monkey_patch() + callable_out_axes = True + else: + callable_out_axes = False + + vmap_wrapper = _VmapWrapper( + _signature=signature, + _fun=fun, + _default=default, + _fn=fn, + _args=bound.args, + _kwargs=bound.kwargs, + _out=out, + _callable_out_axes=callable_out_axes, + _vmapkwargs=vmapkwargs, + ) + return module_update_wrapper(vmap_wrapper, fun) + + +@compile_cache +def _filter_pmap_cache(unwrapped_fun, **pmapkwargs): + @ft.partial(jax.pmap, **pmapkwargs) + @ft.wraps(unwrapped_fun) + def fun_wrapped(dynamic, static_leaves, static_treedef, jit_out_axes): + _fun, _args, _kwargs, _maybe_dummy = hashable_combine( + dynamic, static_leaves, static_treedef + ) + del _maybe_dummy + _out = _fun(*_args, **_kwargs) + _dynamic, _static = partition(_out, jit_out_axes) + return _dynamic, Static(_static) + + return fun_wrapped + + +class _PmapWrapper(Module): + _signature: inspect.Signature + _fun: Callable + _default: AxisSpec + _fn: PyTree[AxisSpec] + _args: PyTree[AxisSpec] + _kwargs: PyTree[AxisSpec] + _out: PyTree[AxisSpec] + _unwrapped_fun: Any + _pmapkwargs: Dict[str, Any] + + def _fun_wrapper(self, is_lower, args, kwargs): + bound = self._signature.bind(*args, **kwargs) + del args, kwargs + bound.apply_defaults() + _args = self._args + (self._default,) * (len(bound.args) - len(self._args)) + _kwargs = {key: self._kwargs.get(key, self._default) for key in bound.kwargs} + try: + axis_size = self._pmapkwargs["axis_size"] + except KeyError: + maybe_dummy = 0 # hashable non-array object + else: + # Work around JAX bug #9252 + maybe_dummy = jnp.empty(axis_size) + in_axes = _resolve_axes( + (self._fun, bound.args, bound.kwargs, maybe_dummy), + (self._fn, _args, _kwargs, _zero_if_array_else_none), + ) + jit_in_axes = _jit_axes(in_axes) + map_in_axes = _map_axes(in_axes) + jit_out_axes = _jit_axes(self._out) + map_out_axes = _map_axes(self._out) + + cached = _filter_pmap_cache( + self._unwrapped_fun, + in_axes=(map_in_axes, None, None), + out_axes=(map_out_axes, None), + static_broadcasted_argnums=(1, 2, 3), + **self._pmapkwargs + ) + + dynamic, static_leaves, static_treedef = hashable_partition( + (self._fun, bound.args, bound.kwargs, maybe_dummy), jit_in_axes + ) + if is_lower: + return cached.lower(dynamic, static_leaves, static_treedef, jit_out_axes) + else: + dynamic_out, static_out = cached( + dynamic, static_leaves, static_treedef, jit_out_axes + ) + return combine(dynamic_out, static_out.value) + + def __call__(__self, *args, **kwargs): + return __self._fun_wrapper(False, args, kwargs) + + def lower(__self, *args, **kwargs): + return __self._fun_wrapper(True, args, kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self + return jax.tree_util.Partial(self, instance) + + +@doc_strip_annotations +def filter_pmap( + fun: Callable = sentinel, + axis_name=None, + *, + default: AxisSpec = _zero_if_array_else_none, + fn: PyTree[AxisSpec] 
= None, + args: PyTree[AxisSpec] = (), + kwargs: PyTree[AxisSpec] = None, + out: PyTree[AxisSpec] = 0, + **pmapkwargs +) -> Callable: + """Wraps together [`equinox.partition`][] and `jax.pmap`. + + !!! info + + By default, the computation is parallelised by splitting all JAX arrays down + their leading axis (i.e. axis index 0), and broadcasting all other types to + each replica. + + **Arguments:** + + In each of the following cases, then `int` indicates an array axis to split down, + `None` indicates that an argument should be broadcast to each device (not split + across devices), and functions `Leaf -> Union[None, bool, int]` are mapped and + evaluated on every leaf of their subtree. + + Note that `jax.pmap`, and thus `equinox.filter_pmap`, also JIT-compile their + function in the same way as `jax.jit`. By default, all JAX arrays are traced and + all other arrays are treated as static inputs. This may be controlled explicitly + -- instead of just passing `None` -- by passing either `True` (traced) or + `False` (static). + + (For `None` and `int`, this is the same semantics as + `jax.vmap(in_axes=..., out_axes=...)`.) + + `None`, `False` and `True` may be used for non-JAX-array arguments. It is an error + to try and specify an integer axis for a non-JAX-array. + + - `fun` is a pure function to parallelise. + - `default` should be a `Union[None, bool, int]` or a function + `Leaf -> Union[None, bool, int]`, and is applied by default to every argument + and keyword argument to `fun`. + - `args` is an optional per-argument override for `default`, and should be a tuple + of PyTrees with leaves that are either `Union[None, bool, int]`s or functions + `Leaf -> Union[None, bool, int]`. The PyTree structures should be prefixes of + the corresponding input to `fun`. + - `kwargs` is an optional per-keyword-argument override for `default` and should be + a dictionary, whose keys are the names of arguments to `fun`, and whose values + are PyTrees with leaves that either `Union[None, bool, int]`s or functions + `Leaf -> Union[None, bool, int]`. The PyTree structures should be prefixes of + the corresponding input to `fun`. + - `out` should be a PyTree with leaves that either `Union[None, bool, int]`s or + functions `Leaf -> Union[None, bool, int]`. The PyTree structure should be a + prefix of the output of `fun`. `True` indicates a tracer, `False` indicates any + auxiliary information to return. + - `fn` should be a PyTree with leaves that either `Union[None, bool, int]`s or + functions `Leaf -> Union[None, bool, int]`. The PyTree structure should be a + prefix of `fun` itself. (Note that `fun` may be any callable, e.g. a bound + method, or a class implementing `__call__`, and doesn't have to be a normal + Python function.) + - `**pmapkwargs` are any other keyword arguments to `jax.pmap`. + + When `args`, `kwargs`, `out`, `fn` are prefixes of the corresponding input, their + value will be mapped over the input PyTree. + + **Returns:** + + The parallelised version of `fun`. + + !!! 
example + + ```python + import equinox as eqx + import jax.numpy as jnp + + @eqx.filter_pmap + def f(x, y): + return x + y + + @eqx.filter_pmap(kwargs=dict(x=1)) + def g(x, y): + return x + y + + @eqx.filter_pmap(args=(None,)) + def h(x, y): + return x + y + + @eqx.filter_pmap + def apply(fun, x): + return fun(x) + + f(jnp.array([1, 2]), jnp.array([3, 4])) # both args split down axis 0 + f(jnp.array([1, 2]), 3) # first arg split down axis 0 + # second arg broadcasted + + g(jnp.array([[1, 2]]), jnp.array([3, 4])) # first arg split down axis 1 + # second arg split down axis 0 + + h(jnp.array(1), jnp.array([2, 3])) # first arg broadcasted + # second arg split down axis 0 + + apply(lambda x: x + 1, jnp.array([2, 3])) # first arg broadcasted (as it's not + # a JAX array) + # second arg split down axis 0 + ``` + """ + + if fun is sentinel: + return ft.partial( + filter_pmap, + axis_name=axis_name, + default=default, + fn=fn, + args=args, + kwargs=kwargs, + out=out, + **pmapkwargs + ) + + if kwargs is None: + kwargs = {} + + signature = inspect.signature(fun) + + signature_default = signature.replace( + parameters=[ + p + if p.kind + in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + else p.replace(default=default) + for p in signature.parameters.values() + ] + ) + bound = signature_default.bind_partial(*args, **kwargs) + del args, kwargs + bound.apply_defaults() + + unwrapped_fun = filter(strip_wrapped_partial(fun), _jit_axes(fn), inverse=True) + + if any(callable(o) for o in jax.tree_leaves(out)): + # In practice we demand `out` be of type `PyTree[ResolvedAxisSpec]`. + raise NotImplementedError( + "`filter_pmap(out_axes=...)` does not support filter functions (only None, bool, int)" + ) + + pmap_wrapper = _PmapWrapper( + _signature=signature, + _fun=fun, + _default=default, + _fn=fn, + _args=bound.args, + _kwargs=bound.kwargs, + _out=out, + _unwrapped_fun=unwrapped_fun, + _pmapkwargs=pmapkwargs, + ) + + return module_update_wrapper(pmap_wrapper, fun) diff --git a/mkdocs.yml b/mkdocs.yml index 7cdd9e6b..d5be5ea3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -102,6 +102,7 @@ nav: - Neural network layers: - 'api/nn/linear.md' - 'api/nn/conv.md' + - 'api/nn/pool.md' - 'api/nn/rnn.md' - 'api/nn/attention.md' - 'api/nn/dropout.md' @@ -113,8 +114,11 @@ nav: - 'api/filtering/filter-functions.md' - 'api/filtering/filtered-transformations.md' - Utilities: - - 'api/helpers.md' - - 'api/stateful.md' + - 'api/utilities/manipulation.md' + - 'api/utilities/serialisation.md' + - 'api/utilities/miscellaneous.md' + - Experimental: + - 'api/experimental/stateful.md' - Misc: - 'citation.md' - 'faq.md' diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 00000000..4268ccdc --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,43 @@ +import functools as ft +import operator + +import jax +import jax.numpy as jnp +import numpy as np + + +def _shaped_allclose(x, y, **kwargs): + if type(x) is not type(y): + return False + if isinstance(x, jnp.ndarray): + if jnp.issubdtype(x.dtype, jnp.inexact): + return ( + x.shape == y.shape + and x.dtype == y.dtype + and jnp.allclose(x, y, **kwargs) + ) + else: + return x.shape == y.shape and x.dtype == y.dtype and jnp.all(x == y) + elif isinstance(x, np.ndarray): + if np.issubdtype(x.dtype, np.inexact): + return ( + x.shape == y.shape + and x.dtype == y.dtype + and np.allclose(x, y, **kwargs) + ) + else: + return x.shape == y.shape and x.dtype == y.dtype and np.all(x == y) + else: + return x == y + + +def shaped_allclose(x, y, **kwargs): + """As 
`jnp.allclose`, except: + - It also supports PyTree arguments. + - It mandates that shapes match as well (no broadcasting) + """ + same_structure = jax.tree_structure(x) == jax.tree_structure(y) + allclose = ft.partial(_shaped_allclose, **kwargs) + return same_structure and jax.tree_util.tree_reduce( + operator.and_, jax.tree_map(allclose, x, y), True + ) diff --git a/tests/test_filter_grad.py b/tests/test_filter_grad.py index 293ce027..02d3571f 100644 --- a/tests/test_filter_grad.py +++ b/tests/test_filter_grad.py @@ -1,4 +1,4 @@ -import functools as ft +from typing import Union import jax import jax.numpy as jnp @@ -9,22 +9,27 @@ import equinox as eqx -def test_filter_grad1(getkey): +@pytest.mark.parametrize("api_version", (0, 1)) +def test_filter_grad1(api_version, getkey): a = jrandom.normal(getkey(), (2, 3)) - @ft.partial(eqx.filter_grad, filter_spec=lambda _: True) def f(x): return jnp.sum(x) + if api_version == 0: + f = eqx.filter_grad(f, filter_spec=lambda _: True) + else: + f = eqx.filter_grad(arg=True)(f) + grad_f = f(a) assert jnp.all(grad_f == 1) -def test_filter_grad2(getkey): +@pytest.mark.parametrize("api_version", (0, 1)) +def test_filter_grad2(api_version, getkey): a = jrandom.normal(getkey(), (2, 3)) b = jrandom.normal(getkey(), (2, 3)) - @ft.partial(eqx.filter_grad, filter_spec=eqx.is_inexact_array) def f(x): sum = 0.0 for arg in jax.tree_leaves(x): @@ -32,6 +37,11 @@ def f(x): sum = sum + jnp.sum(arg) return sum + if api_version == 0: + f = eqx.filter_grad(f, filter_spec=eqx.is_inexact_array) + else: + f = eqx.filter_grad(arg=eqx.is_inexact_array)(f) + ga, gb = f([a, b]) assert jnp.all(ga == 1) assert jnp.all(gb == 1) @@ -60,23 +70,32 @@ def f(x): assert gnp is None -def test_filter_grad3(getkey): +@pytest.mark.parametrize("api_version", (0, 1)) +def test_filter_grad3(api_version, getkey): a = jrandom.normal(getkey(), (2, 3)) b = jrandom.normal(getkey(), (1, 2)) c = jrandom.normal(getkey(), ()) - @ft.partial(eqx.filter_grad, filter_spec=[True, False]) def f(x): return jnp.sum(x[0]) + jnp.sum(x[1]) + if api_version == 0: + f = eqx.filter_grad(f, filter_spec=[True, False]) + else: + f = eqx.filter_grad(arg=[True, False])(f) + ga, gb = f([a, b]) assert jnp.all(ga == 1) assert gb is None - @ft.partial(eqx.filter_grad, filter_spec={"a": True, "b": False}) def h(x, y): return jnp.sum(x["a"]) * jnp.sum(x["b"]) * y + if api_version == 0: + h = eqx.filter_grad(h, filter_spec={"a": True, "b": False}) + else: + h = eqx.filter_grad(arg={"a": True, "b": False})(h) + grad = h({"a": a, "b": b}, c) assert jnp.allclose(grad["a"], jnp.sum(b) * c) assert grad["b"] is None @@ -86,36 +105,112 @@ def h(x, y): # TODO: more comprehensive tests on this. 
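For reference, a minimal sketch (not part of the patch) of the two call forms that the `api_version` parametrisation above exercises; `f` here is purely illustrative:

```python
import equinox as eqx
import jax.numpy as jnp

def f(x):
    return jnp.sum(x)

# api_version == 0: the original API, wrapping the function directly.
grad_old = eqx.filter_grad(f, filter_spec=eqx.is_inexact_array)

# api_version == 1: the new API, used as a decorator factory.
grad_new = eqx.filter_grad(arg=eqx.is_inexact_array)(f)

x = jnp.array([1.0, 2.0])
assert jnp.all(grad_old(x) == grad_new(x))  # both give a gradient of ones
```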
-def test_filter_value_and_grad_(getkey): +@pytest.mark.parametrize("api_version", (0, 1)) +def test_filter_value_and_grad(api_version, getkey): a = jrandom.normal(getkey(), (2, 3)) - @ft.partial(eqx.filter_value_and_grad, filter_spec=eqx.is_inexact_array) def f(x): return jnp.sum(x) + if api_version == 0: + f = eqx.filter_value_and_grad(f, filter_spec=eqx.is_inexact_array) + else: + f = eqx.filter_value_and_grad(arg=eqx.is_inexact_array)(f) + val, grad = f(a) assert val == jnp.sum(a) assert jnp.all(grad == 1) -def test_aux(getkey): +@pytest.mark.parametrize("api_version", (0, 1)) +def test_aux(api_version, getkey): a = jrandom.normal(getkey(), (2, 3)) - @ft.partial(eqx.filter_grad, has_aux=True, filter_spec=eqx.is_inexact_array) def f(x): return jnp.sum(x), "hi" + if api_version == 0: + f = eqx.filter_grad(f, has_aux=True, filter_spec=eqx.is_inexact_array) + else: + f = eqx.filter_grad(has_aux=True, arg=eqx.is_inexact_array)(f) + grad, aux = f(a) assert aux == "hi" assert jnp.all(grad == 1) - @ft.partial( - eqx.filter_value_and_grad, has_aux=True, filter_spec=eqx.is_inexact_array - ) def f(x): return jnp.sum(x), "hi" + if api_version == 0: + f = eqx.filter_value_and_grad(f, has_aux=True, filter_spec=eqx.is_inexact_array) + else: + f = eqx.filter_value_and_grad(has_aux=True, arg=eqx.is_inexact_array)(f) + (value, aux), grad = f(a) assert value == jnp.sum(a) assert aux == "hi" assert jnp.all(grad == 1) + + +@pytest.mark.parametrize("call", [False, True]) +@pytest.mark.parametrize("outer", [False, True]) +def test_methods(call, outer): + class M(eqx.Module): + increment: Union[int, jnp.ndarray] + + if call: + + def __call__(self, x): + return x + self.increment + + if not outer: + __call__ = eqx.filter_grad(__call__) + else: + + def method(self, x): + return x + self.increment + + if not outer: + method = eqx.filter_grad(method) + + m = M(jnp.array(5.0)) + grad_m = M(jnp.array(1.0)) + y = jnp.array(1.0) + + if call: + if outer: + assert eqx.filter_grad(m)(y) == 1 + else: + assert m(y) == grad_m + else: + if outer: + assert eqx.filter_grad(m.method)(y) == 1 + else: + assert m.method(y) == grad_m + + +def test_grad_jit(): + num_traces = 0 + + @eqx.filter_custom_vjp + def f(x): + return x + + def f_fwd(x): + return x, None + + def f_bwd(_, g, __): + nonlocal num_traces + num_traces += 1 + return g + 2 + + f.defvjp(f_fwd, f_bwd) + x = jnp.array(1.0) + + jitf = jax.jit(f) + assert eqx.filter_grad(jitf)(x) == 3 + assert eqx.filter_grad(jitf)(x) == 3 + assert num_traces == 1 + assert eqx.filter_grad(eqx.filter_jit(f))(x) == 3 + assert eqx.filter_grad(eqx.filter_jit(f))(x) == 3 + assert num_traces == 2 diff --git a/tests/test_filter_jit.py b/tests/test_filter_jit.py index 50af9094..60fbda38 100644 --- a/tests/test_filter_jit.py +++ b/tests/test_filter_jit.py @@ -1,9 +1,11 @@ import functools as ft +from typing import Union import jax import jax.numpy as jnp import jax.random as jrandom import pytest +from helpers import shaped_allclose import equinox as eqx @@ -12,7 +14,8 @@ def _eq(a, b): return (type(a) is type(b)) and (a == b) -def test_filter_jit1(getkey): +@pytest.mark.parametrize("api_version", (0, 1)) +def test_filter_jit1(api_version, getkey): a = jrandom.normal(getkey(), (2, 3)) b = jrandom.normal(getkey(), (3,)) c = jrandom.normal(getkey(), (1, 4)) @@ -27,9 +30,17 @@ def test_filter_jit1(getkey): array_tree = [{"a": a, "b": b}, (c,)] _mlp = jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, general_tree[-1]) - @ft.partial(eqx.filter_jit, filter_spec=lambda _: True) - def f(x): - 
return x + if api_version == 0: + + @ft.partial(eqx.filter_jit, filter_spec=lambda _: True) + def f(x): + return x + + else: + + @eqx.filter_jit(default=True) + def f(x): + return x assert jnp.all(a == f(a)) f1 = f(array_tree) @@ -40,10 +51,14 @@ def f(x): with pytest.raises(TypeError): f(general_tree) - @ft.partial(eqx.filter_jit, filter_spec=eqx.is_inexact_array) def g(x): return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x) + if api_version == 0: + g = eqx.filter_jit(g, filter_spec=eqx.is_inexact_array) + else: + g = eqx.filter_jit(default=eqx.is_inexact_array)(g) + assert jnp.all(a == g(a)) g1 = g(array_tree) assert jnp.all(g1[0]["a"] == a) @@ -59,10 +74,14 @@ def g(x): assert jnp.all(g2[4] == c) assert _eq(g2[5], _mlp) - @ft.partial(eqx.filter_jit, filter_spec=eqx.is_array_like) def h(x): return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x) + if api_version == 0: + h = eqx.filter_jit(h, filter_spec=eqx.is_array_like) + else: + h = eqx.filter_jit(h, default=eqx.is_array_like) + assert jnp.all(a == h(a)) h1 = h(array_tree) assert jnp.all(h1[0]["a"] == a) @@ -79,7 +98,8 @@ def h(x): assert _eq(h2[5], _mlp) -def test_filter_jit2(getkey): +@pytest.mark.parametrize("api_version", (0, 1)) +def test_filter_jit2(api_version, getkey): a = jrandom.normal(getkey(), (2, 3)) b = jrandom.normal(getkey(), (3,)) c = jrandom.normal(getkey(), (1, 4)) @@ -93,44 +113,49 @@ def test_filter_jit2(getkey): ] _mlp = jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, general_tree[-1]) - @ft.partial( - eqx.filter_jit, - filter_spec=( - ( - [ - True, - True, - False, - {"a": True, "tuple": (False, True)}, - True, - eqx.is_inexact_array, - ], - ), - {}, - ), - ) + spec = [ + True, + True, + False, + {"a": True, "tuple": (False, True)}, + True, + eqx.is_inexact_array, + ] + def f(x): return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x) - f1 = f(general_tree) - assert _eq(f1[0], jnp.array(1)) - assert _eq(f1[1], jnp.array(True)) - assert _eq(f1[2], None) - assert jnp.all(f1[3]["a"] == a) - assert _eq(f1[3]["tuple"][0], 2.0) - assert jnp.all(f1[3]["tuple"][1] == b) - assert jnp.all(f1[4] == c) - assert _eq(f1[5], _mlp) - - -def test_num_traces(): + if api_version == 0: + wrappers = [ft.partial(eqx.filter_jit, filter_spec=((spec,), {}))] + else: + wrappers = [eqx.filter_jit(args=(spec,)), eqx.filter_jit(kwargs=dict(x=spec))] + + for wrapper in wrappers: + _f = wrapper(f) + f1 = _f(general_tree) + assert _eq(f1[0], jnp.array(1)) + assert _eq(f1[1], jnp.array(True)) + assert _eq(f1[2], None) + assert jnp.all(f1[3]["a"] == a) + assert _eq(f1[3]["tuple"][0], 2.0) + assert jnp.all(f1[3]["tuple"][1] == b) + assert jnp.all(f1[4] == c) + assert _eq(f1[5], _mlp) + + +@pytest.mark.parametrize("api_version", (0, 1)) +def test_num_traces(api_version): num_traces = 0 - @ft.partial(eqx.filter_jit, filter_spec=lambda _: True) def f(x): nonlocal num_traces num_traces += 1 + if api_version == 0: + f = eqx.filter_jit(f, filter_spec=lambda _: True) + else: + f = eqx.filter_jit(default=True)(f) + f(jnp.zeros(2)) f(jnp.zeros(2)) assert num_traces == 1 @@ -146,24 +171,34 @@ def f(x): num_traces = 0 - @ft.partial(eqx.filter_jit, filter_spec=([eqx.is_array_like, False], {})) def g(x, y): nonlocal num_traces num_traces += 1 + if api_version == 0: + g = eqx.filter_jit(g, filter_spec=([eqx.is_array_like, False], {})) + else: + g = eqx.filter_jit(args=[eqx.is_array_like, False])(g) + g(jnp.zeros(2), True) g(jnp.zeros(2), False) assert num_traces == 2 num_traces = 0 - @ft.partial( - 
eqx.filter_jit, filter_spec=([False, {"a": True, "b": False}, False, False], {}) - ) def h(x, y, z, w): nonlocal num_traces num_traces += 1 + if api_version == 0: + h = eqx.filter_jit( + h, filter_spec=([False, {"a": True, "b": False}, False, False], {}) + ) + else: + h = eqx.filter_jit( + h, args=[False, {"a": True, "b": False}], kwargs=dict(z=False, w=False) + ) + h(True, {"a": 1, "b": 1}, True, True) h(False, {"a": 1, "b": 1}, True, True) h(True, {"a": 1, "b": 0}, True, True) @@ -174,35 +209,151 @@ def h(x, y, z, w): assert num_traces == 5 -def test_bound_method(): +@pytest.mark.parametrize("call", [False, True]) +@pytest.mark.parametrize("outer", [False, True]) +def test_methods(call, outer): num_traces = 0 class M(eqx.Module): - def method(self, x): - nonlocal num_traces - num_traces += 1 - return x + 1 + increment: Union[int, jnp.ndarray] + + if call: + + def __call__(self, x): + nonlocal num_traces + num_traces += 1 + return x + self.increment + + if not outer: + __call__ = eqx.filter_jit(__call__) + else: + + def method(self, x): + nonlocal num_traces + num_traces += 1 + return x + self.increment + + if not outer: + method = eqx.filter_jit(method) - m = M() y = jnp.array(1.0) - eqx.filter_jit(m.method)(y) - eqx.filter_jit(m.method)(y) + + def run(_m): + if call: + if outer: + return eqx.filter_jit(_m)(y) + else: + return _m(y) + else: + if outer: + return eqx.filter_jit(_m.method)(y) + else: + return _m.method(y) + + m = M(1) + assert run(m) == 2 + assert run(m) == 2 assert num_traces == 1 + n = M(2) + assert run(n) == 3 + assert run(n) == 3 + assert num_traces == 2 + o = M(jnp.array(1)) + p = M(jnp.array(2)) + assert run(o) == 2 + assert run(p) == 3 + assert num_traces == 3 -def test_callable_class(): +def test_args_kwargs(): num_traces = 0 - class M(eqx.Module): - def __call__(self, x): - nonlocal num_traces - num_traces += 1 - return x + 1 + @eqx.filter_jit(kwargs=dict(x=True)) + def f(*args, **kwargs): + nonlocal num_traces + num_traces += 1 + return kwargs["x"] - m = M() - y = jnp.array(1.0) - eqx.filter_jit(m)(y) - eqx.filter_jit(m)(y) + assert f(x=2) == 2 + assert f(x=3) == 3 + assert num_traces == 1 + + assert f(x=3, y=4) == 3 # check we can use other kwargs + assert num_traces == 2 + + @eqx.filter_jit(default=eqx.is_array_like) + def g(*args, **kwargs): + nonlocal num_traces + num_traces += 1 + return kwargs["x"] + + assert g(x=1, y=1) == 1 + assert g(x=1, y=2) == 1 + assert num_traces == 3 + + @eqx.filter_jit(args=(eqx.is_array,)) + def h(*args, **kwargs): + nonlocal num_traces + num_traces += 1 + return args[0] + + assert h(1, 2) == 1 # check we can use other args + + +def test_jit_jit(): + num_traces = 0 + + @eqx.filter_jit(default=True) + @eqx.filter_jit(default=True) + def f(x): + nonlocal num_traces + num_traces += 1 + return x + 1 + + assert f(1) == 2 + assert f(2) == 3 + assert num_traces == 1 + + @eqx.filter_jit(default=True) + def g(x): + nonlocal num_traces + num_traces += 1 + return x + 1 + + assert eqx.filter_jit(g, default=True)(1) == 2 + assert eqx.filter_jit(g, default=True)(2) == 3 + assert num_traces == 2 + + +def test_jit_grad(): + num_traces = 0 + + def f(x): + nonlocal num_traces + num_traces += 1 + return x + 1 + + assert eqx.filter_jit(eqx.filter_grad(f))(jnp.array(1.0)) == 1 + assert eqx.filter_jit(eqx.filter_grad(f))(jnp.array(2.0)) == 1 + assert num_traces == 1 + + assert eqx.filter_jit(eqx.filter_value_and_grad(f))(jnp.array(1.0)) == (2, 1) + assert eqx.filter_jit(eqx.filter_value_and_grad(f))(jnp.array(2.0)) == (3, 1) + assert num_traces == 2 + 
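As an aside, here is a small sketch (not part of the patch) of the traced-versus-static behaviour that the `num_traces` bookkeeping in these tests relies on; `scale` is a hypothetical function:

```python
import equinox as eqx
import jax.numpy as jnp

# `x` is selected for tracing (it is a JAX array); `flag` is kept static, so
# changing it forces recompilation while new array *values* do not.
@eqx.filter_jit(args=(eqx.is_array, False))
def scale(x, flag):
    return x * 2 if flag else x

scale(jnp.arange(3), True)      # compiles
scale(jnp.arange(3) + 1, True)  # re-uses the compiled function
scale(jnp.arange(3), False)     # recompiles: the static `flag` changed
```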
+ +def test_jit_vmap(): + num_traces = 0 + + def f(x): + nonlocal num_traces + num_traces += 1 + return x + 1 + + out = eqx.filter_jit(eqx.filter_vmap(f))(jnp.array([1, 2])) + assert shaped_allclose(out, jnp.array([2, 3])) + out = eqx.filter_jit(eqx.filter_vmap(f))(jnp.array([2, 3])) + assert shaped_allclose(out, jnp.array([3, 4])) assert num_traces == 1 @@ -256,11 +407,11 @@ def the_test_function_name_value_and_grad(x): in warning_text ) - def wrapped_fun(x, y): - return x + y + def wrapped_fun(y): + pass def the_test_function_name(x, y): - return wrapped_fun(x, y) + return x + y fun = eqx.filter_jit( ft.wraps(wrapped_fun)(ft.partial(the_test_function_name, jnp.array(1.0))) diff --git a/tests/test_filter_pmap.py b/tests/test_filter_pmap.py new file mode 100644 index 00000000..bf14d669 --- /dev/null +++ b/tests/test_filter_pmap.py @@ -0,0 +1,241 @@ +import functools as ft +from typing import Union + +import jax +import jax.numpy as jnp +import pytest +from helpers import shaped_allclose as _shaped_allclose + +import equinox as eqx + + +(cpu,) = jax.devices("cpu") +filter_pmap = ft.partial(eqx.filter_pmap, devices=[cpu]) + + +def shaped_allclose(x, y, **kwargs): + if isinstance(x, jnp.ndarray): + x = jax.device_put(x) + return _shaped_allclose(x, y, **kwargs) + + +def _zero_if_inexact_array_else_none(x): + return 0 if eqx.is_inexact_array(x) else None + + +def test_args(): + @filter_pmap(args=(_zero_if_inexact_array_else_none, [{"a": None}], 0)) + def f(a, b, c, d): + return a + b[0]["a"] + c + d + + out = f(jnp.array([1]), [{"a": jnp.array([2])}], jnp.array([3]), 4) + assert shaped_allclose(out, jnp.array([[10]])) + + +def test_kwargs(): + @filter_pmap(kwargs=dict(a=_zero_if_inexact_array_else_none, b=[{"a": None}], c=0)) + def f(a, b, c, d): + return a + b[0]["a"] + c + d + + out = f(jnp.array([1]), [{"a": jnp.array([2])}], jnp.array([3]), 4) + assert shaped_allclose(out, jnp.array([[10]])) + + +def test_default(): + @filter_pmap(default=_zero_if_inexact_array_else_none) + def f(a, b): + return a + b + + assert shaped_allclose(f(jnp.array(3), jnp.array([3.0])), jnp.array([6.0])) + + with pytest.raises(ValueError): + assert shaped_allclose(f(jnp.array(3.0), jnp.array([3.0])), jnp.array([6.0])) + + +def test_fn(): + class M(eqx.Module): + increment: jnp.ndarray + + def __call__(self, x): + return x + self.increment + + m = M(jnp.array([1])) + o1 = filter_pmap(m, fn=0)(1) + o2 = filter_pmap(m, fn=0)(jnp.array([3])) + o3 = filter_pmap(m, default=None, fn=0)(jnp.array([3])) + assert shaped_allclose(o1, jnp.array([2])) + assert shaped_allclose(o2, jnp.array([4])) + assert shaped_allclose(o3, jnp.array([[4]])) + + +def test_out(): + def f(x): + return x + + o1 = filter_pmap(f, default=None, out=None, axis_size=1)(jnp.array([3, 4])) + o2 = filter_pmap(f, out=0, axis_size=1)(1) + o3 = filter_pmap(f, default=None, out=0, axis_size=1)(jnp.array([3, 4])) + + assert shaped_allclose(o1, jnp.array([3, 4])) + assert shaped_allclose(o2, jnp.array([1])) + assert shaped_allclose(o3, jnp.array([[3, 4]])) + + +def test_no_arrays(): + @filter_pmap(out=None, axis_size=1) + def f(x): + return x + + assert shaped_allclose(f(1), 1) + + +def test_bool(): + num_traces = 0 + + @filter_pmap(args=(True, False), axis_size=1) + def f(x, y): + nonlocal num_traces + num_traces += 1 + return x + y + + assert shaped_allclose(f(1, 2), jnp.array([3])) + assert num_traces == 1 + assert shaped_allclose(f(3, 2), jnp.array([5])) + assert num_traces == 1 + assert shaped_allclose(f(1, 3), jnp.array([4])) + assert num_traces == 2 + 
assert shaped_allclose(f(3, 3), jnp.array([6])) + assert num_traces == 2 + + +@pytest.mark.parametrize("call", [False, True]) +@pytest.mark.parametrize("outer", [False, True]) +def test_methods(call, outer): + num_traces = 0 + + class M(eqx.Module): + increment: Union[int, jnp.ndarray] + + if call: + + def __call__(self, x): + nonlocal num_traces + num_traces += 1 + return x + self.increment + + if not outer: + __call__ = eqx.filter_pmap(__call__) + else: + + def method(self, x): + nonlocal num_traces + num_traces += 1 + return x + self.increment + + if not outer: + method = eqx.filter_pmap(method) + + y = jnp.array([1]) + + def run(_m): + if call: + if outer: + return eqx.filter_pmap(_m)(y) + else: + return _m(y) + else: + if outer: + return eqx.filter_pmap(_m.method)(y) + else: + return _m.method(y) + + m = M(1) + assert shaped_allclose(run(m), jnp.array([2])) + assert shaped_allclose(run(m), jnp.array([2])) + assert num_traces == 1 + n = M(2) + assert shaped_allclose(run(n), jnp.array([3])) + assert shaped_allclose(run(n), jnp.array([3])) + assert num_traces == 2 + o = M(jnp.array([5])) + p = M(jnp.array([6])) + if outer: + assert shaped_allclose(run(o), jnp.array([[6]])) + assert shaped_allclose(run(p), jnp.array([[7]])) + else: + assert shaped_allclose(run(o), jnp.array([6])) + assert shaped_allclose(run(p), jnp.array([7])) + assert num_traces == 3 + + +def test_pmap_grad(): + num_traces = 0 + + def f(x): + nonlocal num_traces + num_traces += 1 + return x + 1 + + grad = eqx.filter_pmap(eqx.filter_grad(f))(jnp.array([1.0])) + assert shaped_allclose(grad, jnp.array([1.0])) + grad = eqx.filter_pmap(eqx.filter_grad(f))(jnp.array([2.0])) + assert shaped_allclose(grad, jnp.array([1.0])) + assert num_traces == 1 + + value, grad = eqx.filter_pmap(eqx.filter_value_and_grad(f))(jnp.array([1.0])) + assert shaped_allclose(value, jnp.array([2.0])) + assert shaped_allclose(grad, jnp.array([1.0])) + value, grad = eqx.filter_pmap(eqx.filter_value_and_grad(f))(jnp.array([2.0])) + assert shaped_allclose(value, jnp.array([3.0])) + assert shaped_allclose(grad, jnp.array([1.0])) + assert num_traces == 2 + + +def test_pmap_vmap(): + num_traces = 0 + + def f(x): + nonlocal num_traces + num_traces += 1 + return x + 1 + + out = eqx.filter_pmap(eqx.filter_vmap(f))(jnp.array([[1, 2]])) + assert shaped_allclose(out, jnp.array([[2, 3]])) + out = eqx.filter_pmap(eqx.filter_vmap(f))(jnp.array([[2, 3]])) + assert shaped_allclose(out, jnp.array([[3, 4]])) + assert num_traces == 1 + + +def test_args_kwargs(): + num_traces = 0 + + @eqx.filter_pmap(kwargs=dict(x=True), axis_size=1) + def f(*args, **kwargs): + nonlocal num_traces + num_traces += 1 + return kwargs["x"] + + assert f(x=2) == 2 + assert f(x=3) == 3 + assert num_traces == 1 + + assert f(x=3, y=4) == 3 # check we can use other kwargs + assert num_traces == 2 + + @eqx.filter_pmap(default=eqx.is_array_like, axis_size=1) + def g(*args, **kwargs): + nonlocal num_traces + num_traces += 1 + return kwargs["x"] + + assert g(x=1, y=1) == 1 + assert g(x=1, y=2) == 1 + assert num_traces == 3 + + @eqx.filter_pmap(args=(eqx.is_array,), axis_size=1) + def h(*args, **kwargs): + nonlocal num_traces + num_traces += 1 + return args[0] + + assert h(1, 2) == 1 # check we can use other args diff --git a/tests/test_filter_vmap.py b/tests/test_filter_vmap.py new file mode 100644 index 00000000..f9b295eb --- /dev/null +++ b/tests/test_filter_vmap.py @@ -0,0 +1,169 @@ +from typing import Union + +import jax.numpy as jnp +import jax.random as jr +import pytest +from helpers import 
shaped_allclose + +import equinox as eqx + + +def _zero_if_inexact_array_else_none(x): + return 0 if eqx.is_inexact_array(x) else None + + +def _zero_if_array_else_none(x): + return 0 if eqx.is_array(x) else None + + +def test_args(): + @eqx.filter_vmap(args=(_zero_if_inexact_array_else_none, [{"a": None}], 0)) + def f(a, b, c, d): + return a + b[0]["a"] + c + d + + out = f(jnp.array([1]), [{"a": jnp.array([2])}], jnp.array([3]), 4) + assert shaped_allclose(out, jnp.array([[10]])) + + +def test_kwargs(): + @eqx.filter_vmap( + kwargs=dict(a=_zero_if_inexact_array_else_none, b=[{"a": None}], c=0) + ) + def f(a, b, c, d): + return a + b[0]["a"] + c + d + + out = f(jnp.array([1]), [{"a": jnp.array([2])}], jnp.array([3]), 4) + assert shaped_allclose(out, jnp.array([[10]])) + + +def test_default(): + @eqx.filter_vmap(default=_zero_if_inexact_array_else_none) + def f(a, b): + return a + b + + assert shaped_allclose(f(jnp.array(3), jnp.array([3.0])), jnp.array([6.0])) + + with pytest.raises(ValueError): + assert shaped_allclose(f(jnp.array(3.0), jnp.array([3.0])), jnp.array([6.0])) + + +def test_fn(): + class M(eqx.Module): + increment: jnp.ndarray + + def __call__(self, x): + return x + self.increment + + m = M(jnp.array([1, 2])) + o1 = eqx.filter_vmap(m, fn=0)(1) + o2 = eqx.filter_vmap(m, fn=0)(jnp.array([3, 4])) + o3 = eqx.filter_vmap(m, default=None, fn=0)(jnp.array([3, 4])) + assert shaped_allclose(o1, jnp.array([2, 3])) + assert shaped_allclose(o2, jnp.array([4, 6])) + assert shaped_allclose(o3, jnp.array([[4, 5], [5, 6]])) + + +def test_out(): + def f(x): + return x + + o1 = eqx.filter_vmap(f, out=None, axis_size=5)(1) + o2 = eqx.filter_vmap(f, default=None, out=None, axis_size=5)(jnp.array([3, 4])) + o3 = eqx.filter_vmap(f, out=0, axis_size=5)(1) + o4 = eqx.filter_vmap(f, default=None, out=0, axis_size=5)(jnp.array([3, 4])) + + assert shaped_allclose(o1, 1) + assert shaped_allclose(o2, jnp.array([3, 4])) + assert shaped_allclose(o3, jnp.array([1, 1, 1, 1, 1])) + assert shaped_allclose(o4, jnp.array([[3, 4], [3, 4], [3, 4], [3, 4], [3, 4]])) + + +def test_no_arrays(): + @eqx.filter_vmap(out=_zero_if_inexact_array_else_none, axis_size=5) + def f(x): + return x + + assert shaped_allclose(f(1), 1) + + +def test_ensemble(getkey): + def make(key): + return eqx.nn.MLP(5, 4, 3, 2, key=getkey()) + + keys = jr.split(getkey(), 7) + models = eqx.filter_vmap(make, out=lambda x: 0 if eqx.is_array(x) else None)(keys) + + def call(model, x): + return model(x) + + xs1 = jr.normal(getkey(), (7, 5)) + assert eqx.filter_vmap(call)(models, xs1).shape == (7, 4) + assert eqx.filter_vmap(models, fn=_zero_if_array_else_none)(xs1).shape == (7, 4) + + xs2 = jr.normal(getkey(), (5,)) + assert eqx.filter_vmap(call, args=(_zero_if_array_else_none, None))( + models, xs2 + ).shape == (7, 4) + assert eqx.filter_vmap(models, default=None, fn=_zero_if_array_else_none,)( + xs2 + ).shape == (7, 4) + + +@pytest.mark.parametrize("call", [False, True]) +@pytest.mark.parametrize("outer", [False, True]) +def test_methods(call, outer): + class M(eqx.Module): + increment: Union[int, jnp.ndarray] + + if call: + + def __call__(self, x): + return x + self.increment + + if not outer: + __call__ = eqx.filter_vmap(__call__) + else: + + def method(self, x): + return x + self.increment + + if not outer: + method = eqx.filter_vmap(method) + + m = M(5) + y = jnp.array([1.0]) + + if call: + if outer: + assert eqx.filter_vmap(m)(y) == 6 + else: + assert m(y) == 6 + else: + if outer: + assert eqx.filter_vmap(m.method)(y) == 6 + else: + assert 
m.method(y) == 6 + + +def test_args_kwargs(): + @eqx.filter_vmap(kwargs=dict(x=0)) + def f(*args, **kwargs): + return kwargs["x"] + + # check we can use other kwargs + assert shaped_allclose(f(x=jnp.array([3]), y=4), jnp.array([3])) + assert shaped_allclose(f(x=jnp.array([3]), y=jnp.array([4])), jnp.array([3])) + + with pytest.raises(ValueError): + f(x=jnp.array([3]), y=jnp.array(4)) + + with pytest.raises(ValueError): + f(x=jnp.array(3)) + + @eqx.filter_vmap(args=(_zero_if_array_else_none,)) + def h(*args, **kwargs): + return args[0] + + # check we can use other args + assert h(1, jnp.array([2])) == 1 + assert shaped_allclose(h(jnp.array([2]), 3), jnp.array([2])) diff --git a/tests/test_filters.py b/tests/test_filters.py index 64854e46..e64209ef 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,3 +1,5 @@ +from typing import Any + import jax import jax.numpy as jnp import numpy as np @@ -109,7 +111,7 @@ def test_filter(getkey): ] filtered = eqx.filter(pytree, filter_spec=filter_spec) none_linear = jax.tree_map(lambda _: None, eqx.nn.Linear(1, 1, key=getkey())) - assert filtered[0] is None + assert filtered[0] == none_linear assert filtered[1] == pytree[1] assert filtered[2][0] == none_linear assert filtered[2][1] is sentinel @@ -139,3 +141,28 @@ def test_partition_and_combine(getkey): assert not isinstance(arg, int) assert eqx.combine(filtered, unfiltered) == pytree assert eqx.combine(unfiltered, filtered) == pytree + + +def test_partition_subtree(): + a, b = eqx.partition([(1,), 2], [True, False]) + eqx.combine(a, b) + + +def test_is_leaf(): + class M(eqx.Module): + value: Any + + def is_m(x): + return isinstance(x, M) + + def filter_spec(x): + if is_m(x): + return x.value == 1 + return True + + pytree = [M(1), M(2), 3] + out = eqx.filter(pytree, filter_spec, is_leaf=is_m) + assert out == [M(1), None, 3] + out1, out2 = eqx.partition(pytree, filter_spec, is_leaf=is_m) + assert out1 == [M(1), None, 3] + assert out2 == [None, M(2), None] diff --git a/tests/test_nn.py b/tests/test_nn.py index 62d5cb89..de4f47d1 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,5 +1,6 @@ import functools as ft import warnings +from typing import List, Union import jax import jax.numpy as jnp @@ -565,6 +566,14 @@ def test_batch_norm(getkey): with pytest.raises(RuntimeError): jax.vmap(bn, axis_name="batch")(x1) + # Test that it handles multiple axis_names + + bn = eqx.experimental.BatchNorm(6, ("batch1", "batch2")) + assert ( + jax.vmap(jax.vmap(bn, axis_name="batch1"), axis_name="batch2")(x2).shape + == x2.shape + ) + # Test that it normalises x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flakey test @@ -645,3 +654,116 @@ def get_weights(m): conv = eqx.nn.Conv3d(5, 4, 3, key=getkey()) conv = eqx.tree_at(lambda c: c.weight, conv, replace_fn=spectral) assert conv(jrandom.normal(getkey(), (5, 8, 8, 8))).shape == (4, 6, 6, 6) + + +def test_maxpool1d(): + + x = jnp.arange(14).reshape(1, 14) + max_pool = eqx.nn.MaxPool1D(2, 3) + output = max_pool(x) + answer = jnp.array([1, 4, 7, 10, 13]) + + assert jnp.all(output == answer) + + +def test_avgpool1d(): + + x = jnp.arange(14).reshape(1, 14) + avg_pool = eqx.nn.AvgPool1D(2, 3) + output = avg_pool(x) + answer = jnp.array([0.5, 3.5, 6.5, 9.5, 12.5]) + + assert jnp.all(output == answer) + + +def test_maxpool2d(): + + x = jnp.arange(36).reshape(1, 6, 6) + max_pool = eqx.nn.MaxPool2D(2, (3, 2)) + output = max_pool(x) + answer = jnp.array([[7, 9, 11], [25, 27, 29]]) + + assert jnp.all(output == answer) + + +def test_avgpool2d(): + 
+ x = jnp.arange(36).reshape(1, 6, 6) + avg_pool = eqx.nn.AvgPool2D((1, 3), 2) + output = avg_pool(x) + answer = jnp.array([[1, 3], [13, 15], [25, 27]]) + + assert jnp.all(output == answer) + + +def test_maxpool3d(): + + x = jnp.arange(64).reshape(1, 4, 4, 4) + max_pool = eqx.nn.MaxPool3D(2, (3, 2, 1)) + output = max_pool(x) + answer = jnp.array([[[21, 22, 23], [29, 30, 31]]]) + + assert jnp.all(output == answer) + + +def test_avgpool3d(): + + x = jnp.arange(64).reshape(1, 4, 4, 4) + avg_pool = eqx.nn.AvgPool3D((1, 3, 1), 2) + output = avg_pool(x) + answer = jnp.array([[[4, 6]], [[36, 38]]]) + + assert jnp.all(output == answer) + + +def test_poolpadding(): + x = jnp.arange(64).reshape(1, 4, 4, 4) + max_pool = eqx.nn.MaxPool3D(2, 1, ((0, 1), (0, 1), (0, 1))) + output = max_pool(x) + + assert output.shape == (1, 4, 4, 4) + + +def test_poolbackprop(): + def max_pool_mean(x): + max_pool = eqx.nn.MaxPool3D((2, 2, 2), (1, 1, 1), ((0, 1), (0, 1), (0, 1))) + return jnp.mean(max_pool(x)) + + x = jnp.arange(64, dtype=jnp.float32).reshape(1, 4, 4, 4) + grad_fn = jax.value_and_grad(max_pool_mean) + + grad_fn(x) + + +def test_poolnetworkbackprop(getkey): + class CNN(eqx.Module): + conv_layer: List[Union[eqx.nn.Conv2d, eqx.nn.MaxPool2D]] + linear_layers: List[eqx.nn.Linear] + + def __init__(self, key): + key1, key2, key3 = jax.random.split(key, 3) + self.conv_layer = [eqx.nn.Conv2d(3, 2, 3, key=key1), eqx.nn.MaxPool2D(2, 2)] + self.linear_layers = [ + eqx.nn.Linear(450, 256, key=key2), + eqx.nn.Linear(256, 10, key=key3), + ] + + def __call__(self, x): + for layer in self.conv_layer: + x = layer(x) + x = jnp.ravel(x) + for layer in self.linear_layers: + x = layer(x) + x = jax.nn.relu(x) + return x + + cnn = CNN(getkey()) + + @jax.vmap + @jax.value_and_grad + def loss_grad(x, y): + return jax.numpy.mean((y - cnn(x)) ** 2) + + x = jrandom.normal(getkey(), (10, 3, 32, 32)) + y = jrandom.normal(getkey(), (10, 10)) + loss_grad(x, y) diff --git a/tests/test_serialisation.py b/tests/test_serialisation.py new file mode 100644 index 00000000..3fd2e9dc --- /dev/null +++ b/tests/test_serialisation.py @@ -0,0 +1,62 @@ +import jax.numpy as jnp +import numpy as np + +import equinox as eqx + + +def test_leaf_serialisation(getkey, tmp_path): + jax_array1 = jnp.array(1) + jax_array2 = jnp.array([1.0, 2.0]) + numpy_array1 = np.array(1) + numpy_array2 = np.array([1.0, 2.0]) + scalars = (True, 1, 1.0, 1 + 1j) + index = eqx.experimental.StateIndex() + index_value = jnp.array(9) + eqx.experimental.set_state(index, index_value) + func = lambda x: x + obj = object() + tree = ( + jax_array1, + jax_array2, + numpy_array1, + numpy_array2, + scalars, + index, + func, + obj, + ) + + eqx.tree_serialise_leaves(tmp_path, tree) + + like_jax_array1 = jnp.array(5) + like_jax_array2 = jnp.array([6.0, 7.0]) + like_numpy_array1 = np.array(5) + like_numpy_array2 = np.array([6.0, 7.0]) + like_scalars = (False, 6, 6.0, 6 + 6j) + like_index = eqx.experimental.StateIndex() + eqx.experimental.set_state(index, jnp.array(6)) + like_func = lambda x: x + like_obj = object() + like = ( + like_jax_array1, + like_jax_array2, + like_numpy_array1, + like_numpy_array2, + like_scalars, + like_index, + like_func, + like_obj, + ) + + tree_loaded = eqx.tree_deserialise_leaves(tmp_path, like) + + tree_serialisable = tree[:-3] + tree_loaded_serialisable = tree_loaded[:-3] + tree_loaded_index, tree_loaded_func, tree_loaded_obj = tree_loaded[-3:] + assert eqx.tree_equal(tree_serialisable, tree_loaded_serialisable) + assert tree_loaded_index is like_index + assert 
jnp.array_equal( + eqx.experimental.get_state(like_index, jnp.array(4)), index_value + ) + assert tree_loaded_func is like_func + assert tree_loaded_obj is like_obj diff --git a/tests/test_stateful.py b/tests/test_stateful.py index e658c8d7..dda9aaaf 100644 --- a/tests/test_stateful.py +++ b/tests/test_stateful.py @@ -197,13 +197,12 @@ def test_inference_no_state(): eqx.experimental.get_state(index, jnp.array(1)) -@pytest.mark.skip def test_inference_not_set_under_jit(): index = eqx.experimental.StateIndex() @jax.jit def f(i): - eqx.tree_at(lambda j: j.inference, i, True) + eqx.tree_inference(i, True) with pytest.raises(RuntimeError): f(index) diff --git a/tests/test_tree.py b/tests/test_tree.py index b8718c04..113ba7d0 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -1,3 +1,5 @@ +import jax +import jax.nn as jnn import jax.numpy as jnp import jax.random as jrandom import pytest @@ -69,6 +71,45 @@ def replace_fn(x): eqx.tree_at(where, pytree, replace=(0, 1), replace_fn=replace_fn) +def test_tree_at_subtree(getkey): + class L(eqx.Module): + def __call__(self, x): + return x + + mlp = eqx.nn.MLP(2, 2, 2, 2, key=getkey()) + + # m.layers is a node in the PyTree + newmlp1 = eqx.tree_at( + lambda m: m.layers, mlp, [L() for _ in range(len(mlp.layers))] + ) + + # tuple(m.layers) is a sequence of nodes in the PyTree. + newmlp2 = eqx.tree_at( + lambda m: tuple(m.layers), mlp, [L() for _ in range(len(mlp.layers))] + ) + + x = jrandom.normal(getkey(), (2,)) + assert (jnn.relu(x) == newmlp1(x)).all() + assert (jnn.relu(x) == newmlp2(x)).all() + + +def test_tree_at_dependent_where(getkey): + mlp = eqx.nn.MLP(2, 2, 2, 2, key=getkey()) + + def where(m): + return jax.tree_leaves(eqx.filter(m, eqx.is_array)) + + with pytest.raises(ValueError): + eqx.tree_at(where, mlp, where(mlp)) + + +def test_tree_at_none_leaf(): + with pytest.raises(ValueError): + eqx.tree_at(lambda y: y[0], (None, None, 0), True) + x = eqx.tree_at(lambda y: y[0], (None, None, 0), True, is_leaf=lambda y: y is None) + assert x == (True, None, 0) + + def test_tree_equal(): key1 = jrandom.PRNGKey(0) key2 = jrandom.PRNGKey(1) @@ -85,3 +126,43 @@ def test_tree_equal(): assert not eqx.tree_equal(pytree1, pytree3) assert not eqx.tree_equal(pytree1, pytree4) assert not eqx.tree_equal(pytree1, pytree5) + + +def test_tree_equal_jit(): + a = jnp.array(0) + b = jnp.array(0) + + @jax.jit + def run1(): + assert not eqx.tree_equal(a, 0) + + run1() + + @jax.jit + def run2(): + return eqx.tree_equal(a, b) + + assert run2() + + @jax.jit + def run3(x, y): + return eqx.tree_equal(x, y) + + assert run3(a, b) + assert not run3(a, 1) + + +def test_tree_inference(getkey): + attention = eqx.nn.MultiheadAttention(2, 4, key=getkey()) + assert attention.dropout.inference is False + attention2 = eqx.tree_inference(attention, True) + assert attention.dropout.inference is False + assert attention2.dropout.inference is True + + @jax.jit + def f(): + dropout = eqx.nn.Dropout() + eqx.tree_inference(dropout, True) + + with pytest.raises(RuntimeError): + f()
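Finally, a short usage sketch of `tree_inference`, mirroring the test above (same constructor arguments; not part of the patch):

```python
import equinox as eqx
import jax.random as jr

attention = eqx.nn.MultiheadAttention(2, 4, key=jr.PRNGKey(0))
eval_attention = eqx.tree_inference(attention, value=True)  # out-of-place copy
assert attention.dropout.inference is False
assert eval_attention.dropout.inference is True
```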