Releases: patrick-kidger/equinox
Equinox v0.11.9
This is an important bugfix release.
- Fix filter_vmap with out_axes!=0,1 producing outputs with the wrong axis order. (Thanks @remifan! #900, #901)
Full Changelog: v0.11.8...v0.11.9
Equinox v0.11.8
The main change in this release is compatibility with JAX 0.4.34, which introduced breaking changes. (#871)
Bugfixes
- Accessing the concrete implementation of an abstract class attribute within __init_subclass__ should no longer crash. (Plus probably better-behaved __init_subclass__ overall.)
Miscellaneous
- JAX 0.4.33 introduced a change that broke eqx.error_if's nicely formatted error messages. With this release, the nice error messages are back!
- eqx.nn.StateIndex can now be passed through jax.jit (and not just eqx.filter_jit). (Thanks @NeilGirdhar! #843)
- Normalization layers now upcast to at least 32-bit precision. (Thanks @AakashKumarNain! #876)
- Poetry has a bug in its interpretation of ~= version constraints. We now work around that for better compatibility with certain kinds of Poetry installations. (Thanks @norpadon! #878)
Documentation
- Updated CNN example to work with recent JAX versions. (Thanks @pasq-cat! #880, #881)
- Updated eqx.tree_at documentation for clarity. (Thanks @jeertmans! #872, #874, #877)
New Contributors
Full Changelog: v0.11.7...v0.11.8
Equinox v0.11.7
Quick release. JAX 0.4.32 / 0.4.33 just introduced a breaking change; this release ensures Equinox is compatible with this. (#856)
Full Changelog: v0.11.6...v0.11.7
Equinox v0.11.6
This is primarily a bug fix release.
- Runtime error messages (those from eqx.error_if, in particular when wrapped with eqx.filter_jit) should now be compatible with PyCharm's debugger, and with certain multithreaded contexts. (Thanks @adam-hartshorne, @dlwh! #828, #844, #849)
- Marking a jax.Array or np.ndarray as an eqx.field(static=True) will now raise a warning. This was technically okay as long as you used it in certain very narrow contexts (e.g. to smuggle it into a JIT'd region without being traced), but in practice it was nearly always just a common new-user footgun. (Thanks @lockwo! #800)
- eqx.tree_at is now better at replacing empty tuples. (Thanks @danielward27! #818, #819)
- eqx.nn.RotaryPositionalEmbedding no longer promotes input dtypes to at least float32. (Thanks @knyazer! #836)
- Mypy now understands that eqx.Modules are dataclasses. (Pyright always did, but mypy needed a slightly different approach to appreciate this fact.) (Thanks @NeilGirdhar! #822)
- Multiple eqx.Modules participating in co-operative multiple inheritance (at least 5 inheriting from each other seem to be necessary?), with some of them overriding the __post_init__s of others, should now follow their expected resolution order. (Thanks @NeilGirdhar! #832, #834)
- We now have a .editorconfig file. (Thanks @NeilGirdhar! #821)
- Doc improvements. (Thanks @garymm, @ColCarroll! #804, #805)
New Contributors
- @garymm made their first contribution in #804
- @ColCarroll made their first contribution in #805
- @NeilGirdhar made their first contribution in #823
Full Changelog: v0.11.5...v0.11.6
Equinox v0.11.5
JAX compatibility
Recent versions of JAX (0.4.28+) have made some changes to:
- Hashing of tracers;
- Tree-map'ing over Nones;
- Callbacks;
- Pretty-printing.
With this update, we should now be compatible with both old and new versions of JAX: this fixes both some new crashes, and some new warnings. (#719, #724, #753, #758, thanks @jakevdp, @hawkinsp!)
Better errors
- The error messages from eqx.error_if are now substantially more informative: they include traceback information (including the stack), and mention the availability of the EQX_ON_ERROR variable. We also do a much better job of hiding the large, unhelpful printouts that XLA gives by default. (#785, #803)
- The default value of EQX_ON_ERROR_BREAKPOINT_FRAMES is now 1. (#777) The impact of this is that using eqx.error_if alongside EQX_ON_ERROR=breakpoint will now:
  - reliably always open a debugger, rather than sometimes crashing at trace time due to upstream JAX bug #16732;
  - however, by default the debugger will no longer include any additional stack frames above it (accessed via u);
  - much of the above is now explained in an informative message printed before the debugger opens.
Bugfixes
- eqx.filter_{jacfwd, jacrev} now only apply filtering to their inputs, not to their outputs. Previously this was problematic, as there was no way to represent static-input-by-static-output in the returned Jacobian, so pieces were silently dropped. (#734, thanks @lockwo!)
- eqx.tree_at can now be used to replace empty tuples. (#715, #717, #722, thanks @lockwo!)
- eqx.filter_custom_jvp no longer raises a trace-time crash in some scenarios in which its **kwargs were erroneously counted as having tangents. (#745 (comment), #749)
- Fixed a trace-time crash when using a particular combination of vmap + autodiff + checkpointed while loops. This occurred when using optimistix.BFGS around diffrax.diffeqsolve. (#777)
- Fixed a trace-time crash when:
  - using a checkpointed while loop...
  - ...with a body function that has a closed-over tracer...
  - ...and that closed-over tracer is differentiated...
  - ...and there are no other closed-over tracers that are differentiated...
  - ...and the dependency on that tracer is only linear.
  (patrick-kidger/diffrax#387 (comment), #752, thanks @dkweiss31!)
- Fixed a trace-time crash when composing the grad of vmap of lineax.linear_solve. (patrick-kidger/lineax#101, #795, thanks @rhacking!)
- eqx.nn.RMSNorm now uses at least 32-bit precision for numerical stability. (#723, thanks @AakashKumarNain!)
New features
- eqx.nn.{Linear,Conv,GRUCell,LSTMCell} now support complex dtypes. (#765, thanks @ChenAo-Phys!)
- Added eqx.nn.RotaryPositionalEmbedding(..., theta=...). (#735, thanks @Artur-Galstyan!)
Other changes
- Several doc fixes. (#708, #731, #733, #747, #750, #757 + several other PRs, thanks @Artur-Galstyan, @matteoguarrera, @lockwo, @nasyxx!)
- Several internal test fixes as downstream libraries have changed slightly. (#740, #742 + several other PRs, big thanks to @GaetanLepage for reporting many of these!)
- There is now a Mistral 7B implementation using JAX+Equinox available over in AakashKumarNain/mistral_jax!
New Contributors
- @nasyxx made their first contribution in #708
- @jakevdp made their first contribution in #724
- @matteoguarrera made their first contribution in #739
Full Changelog: v0.11.4...v0.11.5
Equinox v0.11.4
Features
- Added eqx.filter_shard. This lowers to jax.lax.with_sharding_constraint as a single way to transfer data, or reshard data, both inside and outside of JIT! (No more jax.device_put.) In addition, the parallelism example has been updated to use this simpler new functionality. (Thanks @homerjed and @dlwh! #688, #691)
- Added eqx.filter_{jacfwd,jacrev,hessian}. These do what you expect! (Thanks @lockwo! #677)
- Added eqx.nn.RotaryPositionalEmbedding. This is designed to be used in conjunction with the existing eqx.nn.MultiheadAttention. (Thanks @Artur-Galstyan! #568)
- Added support for padding='VALID', padding='SAME', and padding='SAME_LOWER' to the convolutional layers: eqx.nn.{Conv, ...}. (Thanks @ChenAo-Phys! #658)
- Added support for padding_mode='ZEROS', padding_mode='REFLECT', padding_mode='REPLICATE', and padding_mode='CIRCULAR' to the convolutional layers: eqx.nn.{Conv, ...}. (Thanks @ChenAo-Phys! #658)
- Added a dtype argument to eqx.nn.{MultiheadAttention, Linear, Conv, ...} for specifying the dtype of their parameters. In addition, eqx.nn.BatchNorm now also uses its dtype argument to determine the dtype of its weights and bias, not just the dtype of its moving statistics. (Thanks @Artur-Galstyan and @AakashKumarNain! #680, #689)
Compatibility
- eqx.error_if is now compatible with JAX 0.4.26, which changed JAX's own reporting of error messages slightly. (Thanks @hawkinsp! #670)
- Added a warning that checks for doing something like:

    class MyModule(eqx.Module):
        fn: Callable

        def __init__(self, ...):
            self.fn = jax.vmap(some_fn)

  This is an easy source of bugs. (The vmap'd function is not a PyTree, so it will not propagate anything in the PyTree structure of some_fn.)
Technical internal stuff
- eqx.internal.while_loop(..., kind="checkpointed") will now only propagate forward JVP tracers for those outputs which are perturbed due to the input to the loop being perturbed. (Rather than all of them.) This change just means that later calls to a nondifferentiable operation, like jax.pure_callback or eqx.internal.nondifferentiable, will no longer crash at trace time. (See patrick-kidger/diffrax#396.)
- eqx.internal.while_loop(..., kind="bounded") will now handle certain vmap+grad combinations without crashing. (It seems like JAX is adding some spurious batch tracers.) (See patrick-kidger/optimistix#48 (comment))
- The transpose rule for eqx.internal.create_vprim now understands symbolic zeros, fixing a crash for grad-of-vmap-of-<lineax.linear_solve that we only use some outputs from>. (See patrick-kidger/optimistix#48.)
- The type annotation for the input of any converter function used in eqx.field(converter=...) will now be used as the type annotation in any dataclass-autogenerated __init__ functions. In particular, this should mean such functions are now compatible with runtime type checkers like beartype. (jaxtyping users, you were already covered: this checks the assigned annotations instead.)
New Contributors
- @ChenAo-Phys made their first contribution in #658
- @hawkinsp made their first contribution in #670
- @AakashKumarNain made their first contribution in #680
- @imilas made their first contribution in #699
Full Changelog: v0.11.3...v0.11.4
Equinox v0.11.3
Features
- Added equinox.nn.RMSNorm.
- Added equinox.nn.WeightNorm.
- equinox.tree_deserialise_leaves now treats jax.ShapeDtypeStructs in the same way as arrays. This makes it possible to avoid instantiating the initial model parameters only to throw them away again, by using equinox.filter_eval_shape (#259):

    model = eqx.filter_eval_shape(Model, ...hyperparameters...)
    model = eqx.tree_deserialise_leaves(load_path, model)
Bugfixes
- equinox.internal.noinline no longer initialises the JAX backend on use.
- equinox.filter_jit(...).lower(..., some_kwarg=...) no longer crashes. (#625, #627)
- The state of equinox.nn.BatchNorm now uses the default floating point dtype, rather than always using float32.
- equinox.nn.MultiheadAttention should now perform the softmax in float32 even when the input is of lower dtype. (This is important for numerical stability.)
Refactor
- All the layers in equinox.nn.{Linear, MLP, ...} now standardise on accepting extra **kwargs and not calling super().__init__. The intention is that these layers be treated as final, i.e. not subclassable. (Previously things were inconsistent: some did this and some did not.)
- Should now be compatible with JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise, and this is checked in tests.
- Better error message when no kwargs are passed to filter_grad. (Thanks @knyazer! #589)
Internal features
These are undocumented internal features, that may be changed at any time.
- Added EQX_GETKEY_SEED for use with equinox.internal.GetKey.
- equinox.internal.while_loop now has its runtime errors removed. This should help with compatibility with TPUs. (#628)
New Contributors
- @haydn-jones made their first contribution in #608
Full Changelog: v0.11.2...v0.11.3
Equinox v0.11.2
Features
- Added eqx.filter_jit(..., donate="all-except-first") and eqx.filter_jit(..., donate="warn-except-first"). This offers a way to donate all arguments except the first one. (If you have multiple such arguments then just pack them together into a tuple in the first argument.) This aims to be a low-overhead, easy way to handle buffer donation.
- Added eqx.debug.{assert_max_traces, get_num_traces}, which aim to provide a friendly way of asserting that a JIT'd function is not recompiled, and if it is, which argument changed to cause the recompilation.
- eqx.tree_pprint and eqx.tree_pformat now handle PyTorch tensors and jax.ShapeDtypeStructs.
- eqx.tree_equal now has new arguments:
  - typematch=True: this will require that every leaf have precisely the same type as each other. By default the requirement is essentially leaf == leaf2; with this flag it becomes type(leaf) == type(leaf2) and leaf == leaf2.
  - rtol and atol: setting these to nonzero values allows for checking that inexact (floating or complex) arrays are allclose, rather than exactly equal.
  - The expectation is that these will be useful in unit tests, e.g. to write checks of the form assert eqx.tree_equal(output, expected_output, typematch=True, rtol=1e-5, atol=1e-5).
Bugfixes
- Previously, a learnt activation function for eqx.nn.MLP would use the exact same learnt weights for every neuron in every layer. Now, a separate copy of the activation function is used in each location.
- Subclasses of eqx.Module should now have their __init__ signatures correctly reported by downstream tooling, e.g. automated doc generators and some IDEs. (Thanks @danielward27! #573)
Typing
eqx.filter_value_and_grad now declares that it preserves the return type of its function. (Thanks @ConnorBaker! #557)
Documentation
- Fixed a missing index argument in the docstring example for StateIndex. (Thanks @edwardwli! #556)
- Fixed a broken link in the eqx.Enumeration docstrings. (Thanks @LouisDesdoigts! #579)
- Fixed a missing shape specification in one of the tricks. (Thanks @homerjed! #582)
Other
- Improved a few IPython tracebacks with appropriate __tracebackhide__ = True assignments.
- Subclassed eqx.Enumerations can now override the message associated with their parent Enumeration: this now produces a warning rather than an error.
- Documented the EQX_ON_ERROR_BREAKPOINT_FRAMES config variable, which is used to work around a JAX bug when setting EQX_ON_ERROR=breakpoint.
- Can now monkey-patch the methods of an eqx.Module, e.g.

    class Foo(eqx.Module):
        def f(self):
            ...

    Foo.f = some_transform(Foo.f)

  The anticipated use-case for this is to make it easier for typecheckers; see #584.
- eqx.debug.store_dce now supports non-arrays in its argument.
- eqx.Enumeration.where(traced_pred, x, x) will now statically return x without tracing. This is occasionally useful to better propagate information at compile time.
Internal features (not officially supported, advanced use only)
- Added eqx.internal.GetKey. This generates a random JAX PRNG key when called, and crucially has a nice __repr__ reporting what the seed value is. This should not be used in normal JAX code! This is intended as a convenience for tests, so that the random seed appears in the debug printout of a failed test.
- Added eqx.internal.MaybeBuffer to indicate that an argument of an eqx.internal.{while_loop,scan} might be wrapped in a buffer.
- Added eqx.internal.buffer_at_set to support buffer.at[...].set(..., pred=...) whilst being agnostic to whether buffer is a JAX array or one of our while loop buffers.
New Contributors
- @edwardwli made their first contribution in #556
- @ConnorBaker made their first contribution in #557
- @danielward27 made their first contribution in #573
Full Changelog: v0.11.1...v0.11.2
Equinox v0.11.1
This is a minor bugfix release.
Bugfixes
- Checkpointed while loops (eqx.internal.while_loop(..., kind="checkpointed")) now perform a more careful analysis of which arguments need to be differentiated. (#548) This fix is the primary reason for this release, as it unlocks some efficiency improvements when solving SDEs in Diffrax (see patrick-kidger/diffrax#320).
- Fixed Abstract{Class,}Var misbehaving around multiple inheritance. (#544)
- Better compatibility with the beartype library. In a few cases this was throwing some spurious errors to do with forward references. (#543)
Documentation
Other
- Static type checkers should now use Equinox's type hints correctly. (Specifically, we now have the py.typed marker file. Thanks @vidhanio! #547)
- Added the EQX_ON_ERROR_BREAKPOINT_FRAMES environment variable, to work around JAX bug jax-ml/jax#16732 when using EQX_ON_ERROR=breakpoint. This new variable sets the number of stack frames you can access via the u debugger command when the on-error debugger is triggered. Set this to a small enough number, e.g. EQX_ON_ERROR_BREAKPOINT_FRAMES=1, and it should fix unusual trace-time errors when using EQX_ON_ERROR=breakpoint.
New Contributors
Full Changelog: v0.11.0...v0.11.1
Equinox v0.11.0
Better errors
Equinox now includes several additional checks to guard against various bugs. If you have a new error, then this is probably an indication that your code always had a silent bug, and should be updated.
- eqx.nn.LayerNorm now correctly validates the shape of its input. This was a common cause of silent bugs. (Thanks @dlwh for pointing this one out!)
- Equinox now prints out a warning if you supply both __init__ and __post_init__: a custom __init__ means that __post_init__ is never called. (This is normal Python dataclass behaviour, but probably unexpected.)
- Equinox now prevents you from assigning Module attributes with a bound method of your current instance, e.g.

    class Model(eqx.Module):
        foo: Callable

        def __init__(self):
            self.foo = self.bar

        def bar(self):
            ...

  Otherwise, you end up with two different copies of your model! One at self, the other at self.foo.__self__. (The latter being in the bound method.)
- eqx.tree_at now gives a better error message if you try to use it to update something that isn't a PyTree leaf. (Thanks @LouisDesdoigts!)
API changes
These should all be very minor.
- Breaking change: eqx.nn.StateIndex now takes the initial value, rather than a function that returns the initial value.
- Breaking change: If using eqx.field(converter=...), then conversion now happens before __post_init__, rather than after it.
- Prefer eqx.nn.make_with_state over eqx.nn.State. The latter will continue to work, but the former is more memory-efficient. (It deletes the original copy of the initial state.)
- Prefer eqx.nn.inference_mode over eqx.tree_inference. The latter will continue to exist for backward compatibility. These are the same function; this is really just a matter of moving it into the eqx.nn namespace where it always belonged.
Sharing layers
Equinox now supports sharing a layer between multiple parts of your model! This has probably been our longest-requested feature -- in large part because of how intractable it seemed. Equinox models are PyTrees, not PyDAGs, so how exactly are we supposed to have two different parts of our model point at the same layer?
The answer turned out to be the following -- in this example, we're reusing the embedding weight matrix between the initial embedding layer, and the final readout layer, of a language model.
class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self):
        embedding = eqx.nn.Embedding(...)
        linear = eqx.nn.Linear(...)
        # These two weights will now be tied together.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, tokens):
        # Expand back out so we can evaluate these layers.
        embedding, linear = self.shared()
        assert embedding.weight is linear.weight  # same parameter!
        # Now go ahead and evaluate your language model.
        ...
Here, eqx.nn.Shared(...) simply removes all of the nodes at where, so that we don't have two separate copies. Then, when it is called at self.shared(), it puts them back again. Note that this isn't a copy and doesn't incur any additional memory overhead; this all happens at the Python level, not the XLA level.
(The curious may like to take a look at the implementation in equinox/nn/_shared.py, which turned out to be very simple.)
On a meta level, I'd like to comment that I'm quite proud of having gotten this one in! It means that Equinox now supports both stateful layers and shared layers, which have always been the two pieces that seemed out of reach when using something as simple as PyTrees to represent models. But it turns out that PyTrees really are all you need. :D
Other changes
Documentation
- Many documentation fixes courtesy of @colehaus and @Artur-Galstyan!
- Added two new examples to the documentation. Thank you to @ahmed-alllam for both of them!
  - Deep convolutional GAN
  - Vision Transformer
- Added an FAQ entry on comparisons between Equinox and PyTorch/Keras/Julia/Flax. It's a common enough question that should probably have had an answer before now.
- Added an FAQ entry on debugging recompilation.
Features
- Added eqx.filter_checkpoint, which, as you might expect, is a filtered version of jax.checkpoint. (Thanks @dlwh!)
- Added eqx.Module.__check_init__. This is run in a similar fashion to __post_init__; see the documentation. This can be used to check that invariants of your module hold after initialisation.
- Added support for vmap'ing stateful layers, by adding eqx.nn.State.{substate, update}. This offers a way to subset or update a State object, so that only the parts of it that need to be vmap'd are passed in. See the stateful documentation for an example of how to do this.
- Runtime errors should now produce much more readable results, without any of the terrifying INTERNAL: Generated function failed: CpuCallback error stuff! This clean-up of the runtime error message is done by eqx.filter_jit, so that will need to be your top-level way of JIT'ing your computation.
- Added eqx.nn.StatefulLayer. This is used (only!) with eqx.nn.Sequential, to indicate that the layer should be called with x, state, and not just x. If you would like a custom stateful layer to be compatible with Sequential, then go ahead and subclass this, and potentially implement the is_stateful method. (Thanks @paganpasta!)
- The forward pass of each eqx.nn.* layer is now wrapped in a jax.named_scope, for a better debugging experience. (Thanks @ahmed-alllam!)
- eqx.module_update_wrapper no longer requires a second argument; it will look at the __wrapped__ attribute of its first argument.
- Added eqx.internal.closure_to_pytree, for... you guessed it, turning function closures into PyTrees. The closed-over variables are treated as the subnodes in the PyTree. This will operate recursively, so that closed-over closures will themselves become PyTrees, etc. Note that closed-over global variables are not included.
Bugfixes
- eqx.tree_{serialise,deserialise}_leaves now correctly handle unusual NumPy scalars, like bfloat16. (Thanks @colehaus!)
- eqx.field(metadata=...) arguments no longer result in the static/converter arguments being ignored. (Thanks @mjo22!)
- eqx.filter_custom_vjp now supports residuals that are not arrays. (The residuals are the pytree that is passed between the forward and backward pass.)
- eqx.{AbstractVar,AbstractClassVar} should now support overridden generics in subclasses. That is, something like this:

    class Foo(eqx.Module):
        x: eqx.AbstractVar[list[str]]

    class Bar(Foo):
        x: list[str]

  should no longer raise spurious errors under certain conditions.
- eqx.internal.while_loop now supports using custom (non-Equinox) pytrees in the state.
- eqx.tree_check no longer raises some false positives.
- Equinox modules now support __init_subclass__ with additional class creation kwargs. (Thanks @ASEM000, @Roger-luo!)
New Contributors
- @homerjed made their first contribution in #445
- @LouisDesdoigts made their first contribution in #460
- @knyazer made their first contribution in #474
Full Changelog: v0.10.11...v0.11.0