Version 0.5.0 (#82)
* Improved filtered transformations (#71)

* Updated filter_{jit,grad,value_and_grad}

* Documented that BatchNorm supports multiple axis_names

* Doc fix

* Added tree_inference

* Added filter_vmap and filter_pmap

* Made filter_{vmap,pmap} non-experimental. Fixed doc issues.

* Added 'inference' call-time argument to 'MultiheadAttention', which was missed when switching over 'Dropout' to use 'inference'.

* Added test for tree_inference

* Doc improvements

* Doc tweaks

* Updated tests

* Two minor bugfixes for filter_{jit,grad,vmap,pmap,value_and_grad}.

1. Applying filter_{jit,vmap,pmap} to a function with *args and **kwargs should now work. Previously inspect.Signature.apply_defaults was not filling these in.
2. The output of all filtered transformations has become a Module. This means that e.g. calling filter_jit(filter_grad(f)) multiple times will not induce any recompilation, as all filter_grad(f)s are PyTrees of the same structure as each other.

* Crash and test fixes

* Tidy up

* Fixes and test fixes

* Finished adding tests.

* Added is_leaf argument to filter and partition (#72)

* Improved tree_at a lot: (#73)

- Can now substitute arbitrary nodes, not just leaves
- These substituted nodes can now have different structures to each other.
- Will raise an error if `where` depends on the values of the leaves of the PyTree.

In addition, `tree_equal` should now work under JIT.

* Lots of doc fixes and tweaks (#74)

* Added tree_{de,}serialise_leaves (#80)

* Added max/average pooling. (#77)

* Tidy up pooling implementation (#81)

* Version bump

* Fixed tree_at breaking when replacing Nones. (#83)

Co-authored-by: Ben Walker <[email protected]>
patrick-kidger and Benjamin-Walker authored May 6, 2022
1 parent 9325144 commit 291c4d7
Showing 40 changed files with 2,970 additions and 384 deletions.
5 changes: 5 additions & 0 deletions docs/_static/custom_css.css
@@ -1,3 +1,8 @@
/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */
html {
scroll-padding-top: 50px;
}

/* Fit the Twitter handle alongside the GitHub one in the top right. */

div.md-header__source {
4 changes: 2 additions & 2 deletions docs/all-of-equinox.md
@@ -145,6 +145,6 @@ See also the API reference on the left.

!!! faq "FAQ"

One common question that gets asked: a lot of other libraries introduce custom `library.jit` etc. operations, specifically to work with `library.Module`. What makes the filtered transformations of Equinox different?
One common question: a lot of other libraries introduce custom `library.jit` etc. operations, specifically to work with `library.Module`. What makes the filtered transformations of Equinox different?

The answer is that filters are tools that apply to *any PyTree*; and models just happen to be PyTrees. There is no special coupling between a filtered transformation and `eqx.Module`. (Not to mention that the purpose of filtered transformations -- being able to include non-JAX-arrays in your model's parameters -- is itself unusual/impossible in many other libraries.)
The answer is that filter transformations are tools that apply to any PyTree. And models just happen to be PyTrees. The filtered transformations and `eqx.Module` are not coupled together.
42 changes: 23 additions & 19 deletions docs/api/stateful.md → docs/api/experimental/stateful.md
@@ -6,6 +6,8 @@ These operations can be used to save/load JAX arrays as a side-effect

This is considered experimental.

Stateful operations will not produce correct results under `jax.checkpoint` or `jax.pmap`.

!!! danger

Really, **this is experimental**. Side effects can easily make your code do something unexpected. Whatever you're doing, you almost certainly do not need this.
@@ -15,33 +17,35 @@ Use cases:
- Something like [`equinox.experimental.BatchNorm`][], for which we would like to save the running statistics as a side-effect.
- Implicitly passing information between loop iterations -- i.e. rather than explicitly via the `carry` argument to `lax.scan`. Perhaps you're using a third-party library that handles the `lax.scan`, which doesn't allow you to pass your own information between iterations.

Example:
```python
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
!!! example

```python
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp

index = eqx.experimental.StateIndex()
init = jnp.array(0)
eqx.experimental.set_state(index, init)
index = eqx.experimental.StateIndex()
init = jnp.array(0)
eqx.experimental.set_state(index, init)

@jax.jit
def scan_fun(_, __):
val = eqx.experimental.get_state(index, like=init)
val = val + 1
eqx.experimental.set_state(index, val)
return None, val
@jax.jit
def scan_fun(_, __):
val = eqx.experimental.get_state(index, like=init)
val = val + 1
eqx.experimental.set_state(index, val)
return None, val

_, out = lax.scan(scan_fun, None, xs=None, length=5)
print(out) # [1 2 3 4 5]
```
_, out = lax.scan(scan_fun, None, xs=None, length=5)
print(out) # [1 2 3 4 5]
```

---

::: equinox.experimental.StateIndex
selection:
members: false
members:
- __init__

---

10 changes: 9 additions & 1 deletion docs/api/filtering/filtered-transformations.md
@@ -1,6 +1,6 @@
# Filtered transformations

These typically combine [`equinox.partition`][], a filter function, and a JAX transformation, all together.
These typically combine [`equinox.partition`][], a [filter function](./filter-functions.md), and a JAX transformation, all together.

Practically speaking these are usually the only kind of filtering you ever have to use. (But it's good to understand what e.g. [`equinox.partition`][] and [`equinox.is_array`][] are doing under the hood, just so that these don't seem too magical.)

@@ -19,3 +19,11 @@ Practically speaking these are usually the only kind of filtering you ever have
::: equinox.filter_custom_vjp
selection:
members: false

---

::: equinox.filter_vmap

---

::: equinox.filter_pmap
43 changes: 43 additions & 0 deletions docs/api/nn/pool.md
@@ -0,0 +1,43 @@
# Pooling

::: equinox.nn.Pool
selection:
members:
- __init__
- __call__

---

::: equinox.nn.AvgPool1D
selection:
members: false

---

::: equinox.nn.AvgPool2D
selection:
members: false

---

::: equinox.nn.AvgPool3D
selection:
members: false

---

::: equinox.nn.MaxPool1D
selection:
members: false

---

::: equinox.nn.MaxPool2D
selection:
members: false

---

::: equinox.nn.MaxPool3D
selection:
members: false
11 changes: 11 additions & 0 deletions docs/api/utilities/manipulation.md
@@ -0,0 +1,11 @@
# Manipulating PyTrees

::: equinox.apply_updates

---

::: equinox.tree_at

---

::: equinox.tree_inference
10 changes: 1 addition & 9 deletions docs/api/helpers.md → docs/api/utilities/miscellaneous.md
@@ -1,12 +1,4 @@
# Helpers for PyTrees

::: equinox.apply_updates

---

::: equinox.tree_at

---
# Miscellaneous

::: equinox.tree_pformat

7 changes: 7 additions & 0 deletions docs/api/utilities/serialisation.md
@@ -0,0 +1,7 @@
# Serialisation

::: equinox.tree_serialise_leaves

---

::: equinox.tree_deserialise_leaves
9 changes: 4 additions & 5 deletions docs/faq.md
@@ -37,11 +37,11 @@ Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as

The resolution is simple: just don't store the same object in multiple places in the PyTree.

## I cannot feed higher-order tensors (e.g. with batch dimensions) into my model.
## How do I input higher-order tensors (e.g. with batch dimensions) into my model?

Use [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap). This maps arbitrary JAX operations -- including any Equinox module -- over additional dimensions.
Use [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap). This maps arbitrary JAX operations -- including any Equinox module -- over additional dimensions (such as batch dimensions).

For example if `x` is an array/tensor of shape `(batch_size, input_size)`, the following code in PyTorch:
For example if `x` is an array/tensor of shape `(batch_size, input_size)`, then the following PyTorch code:

```python
import torch
@@ -50,8 +50,7 @@ linear = torch.nn.Linear(input_size, output_size)
y = linear(x)
```

is equivalent to the following code in Equinox:

is equivalent to the following Equinox code:
```python
import jax
import equinox as eqx
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -4,7 +4,7 @@ mkdocs-material==7.3.6 # Theme
pymdown-extensions==9.4 # Markdown extensions e.g. to handle LaTeX.
mkdocstrings==0.17.0 # Autogenerate documentation from docstrings.
mknotebooks==0.7.1 # Turn Jupyter Lab notebooks into webpages.
pytkdocs_tweaks==0.0.4 # Tweaks mkdocstrings to improve various aspects
pytkdocs_tweaks==0.0.5 # Tweaks mkdocstrings to improve various aspects
mkdocs_include_exclude_files==0.0.1 # Tweak which files are included/excluded
jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings.

6 changes: 4 additions & 2 deletions equinox/__init__.py
@@ -12,8 +12,10 @@
from .jit import filter_jit
from .module import Module, static_field
from .pretty_print import tree_pformat
from .tree import tree_at, tree_equal
from .serialisation import tree_deserialise_leaves, tree_serialise_leaves
from .tree import tree_at, tree_equal, tree_inference
from .update import apply_updates
from .vmap_pmap import filter_pmap, filter_vmap


__version__ = "0.4.0"
__version__ = "0.5.0"
46 changes: 46 additions & 0 deletions equinox/compile_utils.py
@@ -0,0 +1,46 @@
import functools as ft
from typing import Any

import jax

from .filters import combine, partition
from .module import Module, static_field


def hashable_partition(pytree, filter_spec):
dynamic, static = partition(pytree, filter_spec)
static_leaves, static_treedef = jax.tree_flatten(static)
static_leaves = tuple(static_leaves)
return dynamic, static_leaves, static_treedef


def hashable_combine(dynamic, static_leaves, static_treedef):
static = jax.tree_unflatten(static_treedef, static_leaves)
return combine(dynamic, static)


class Static(Module):
value: Any = static_field()


def strip_wrapped_partial(fun):
if hasattr(fun, "__wrapped__"): # ft.wraps
return strip_wrapped_partial(fun.__wrapped__)
if isinstance(fun, ft.partial):
return strip_wrapped_partial(fun.func)
return fun


def compile_cache(fun):
@ft.lru_cache(maxsize=None)
def _cache(leaves, treedef):
args, kwargs = jax.tree_unflatten(treedef, leaves)
return fun(*args, **kwargs)

@ft.wraps(fun)
def _fun(*args, **kwargs):
leaves, treedef = jax.tree_flatten((args, kwargs))
leaves = tuple(leaves)
return _cache(leaves, treedef)

return _fun
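To illustrate `compile_cache` above (restated here via `jax.tree_util`, which newer JAX versions require, and with a call counter of our own): identical hashable arguments flatten to the same `(leaves, treedef)` key, so the wrapped function body runs only once.

```python
import functools as ft

import jax

def compile_cache(fun):
    @ft.lru_cache(maxsize=None)
    def _cache(leaves, treedef):
        args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
        return fun(*args, **kwargs)

    @ft.wraps(fun)
    def _fun(*args, **kwargs):
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        return _cache(tuple(leaves), treedef)

    return _fun

calls = []

@compile_cache
def add(x, y):
    calls.append((x, y))
    return x + y

a = add(1, 2)
b = add(1, 2)  # same hashable arguments: cached, body not re-run
```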
9 changes: 8 additions & 1 deletion equinox/custom_types.py
@@ -1,9 +1,11 @@
import inspect
import typing
from typing import Generic, Tuple, TypeVar, Union
from typing import Any, Callable, Generic, Tuple, TypeVar, Union

import jax

from .doc_utils import doc_repr


# Custom flag we set when generating documentation.
# We do a lot of custom hackery in here to produce nice-looking docs.
@@ -105,4 +107,9 @@ def __class_getitem__(cls, item):
return PyTree


sentinel = doc_repr(object(), "sentinel")

TreeDef = type(jax.tree_structure(0))

ResolvedBoolAxisSpec = bool
BoolAxisSpec = Union[ResolvedBoolAxisSpec, Callable[[Any], ResolvedBoolAxisSpec]]
32 changes: 32 additions & 0 deletions equinox/doc_utils.py
@@ -0,0 +1,32 @@
import typing
from types import FunctionType
from typing import Any


# Inherits from type so that _WithRepr instances are types and can be used as
# e.g. Sequence[_WithRepr(...)]
class _WithRepr(type):
def __new__(self, string):
out = super().__new__(self, string, (), {})
# prevent the custom typing repr from doing the wrong thing
out.__module__ = "builtins"
return out

def __init__(self, string):
self.string = string

def __repr__(self):
return self.string


def doc_repr(obj: Any, string: str):
if getattr(typing, "GENERATING_DOCUMENTATION", False):
return _WithRepr(string)
else:
return obj


def doc_strip_annotations(fn: FunctionType) -> FunctionType:
if getattr(typing, "GENERATING_DOCUMENTATION", False):
fn.__annotations__ = None
return fn
11 changes: 7 additions & 4 deletions equinox/experimental/batch_norm.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Sequence, Union

import jax
import jax.lax as lax
@@ -60,7 +60,7 @@ class BatchNorm(Module):
bias: Optional[Array["input_size"]]
first_time_index: StateIndex
state_index: StateIndex
axis_name: str
axis_name: Union[str, Sequence[str]]
inference: bool
input_size: int = static_field()
eps: float = static_field()
@@ -81,15 +81,18 @@ def __init__(
- `input_size`: The number of channels in the input array.
- `axis_name`: The name of the batch axis to compute statistics over, as passed
to `axis_name` in `jax.vmap` or `jax.pmap`.
to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (tuple,
list) of strings to compute statistics over multiple named axes.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable channel-wise affine
parameters.
- `momentum`: The rate at which to update the running statistics. Should be a
value between 0 and 1 exclusive.
- `inference`: If `False` then the batch means and variances will be calculated
and used to update the running statistics. If `True` then the running
statistics are directly used for normalisation.
statistics are directly used for normalisation. This may be toggled with
[`equinox.tree_inference`][] or overridden during
[`equinox.experimental.BatchNorm.__call__`][].
"""

super().__init__(**kwargs)
1 change: 1 addition & 0 deletions equinox/experimental/spectral_norm.py
@@ -135,6 +135,7 @@ def __init__(
- `num_power_iterations`: The number of power iterations to apply every time the array is accessed.
- `eps`: Epsilon for numerical stability when calculating norms.
- `inference`: Whether this is in inference mode, at which time no power iterations are performed.
This may be toggled with [`equinox.tree_inference`][].
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation. (Keyword only argument.)
"""
super().__init__(**kwargs)
