Equinox v0.5.0
This is a big update.
Exciting new features!
- Added `filter_vmap`.
    - This can be used to create ensembles of models.
    - (Closes #65.)
- Added `filter_pmap`.
    - (Closes #65.)
- Added pooling layers: `eqx.nn.Pool`, `eqx.nn.AvgPool1d`, `eqx.nn.AvgPool2d`, `eqx.nn.AvgPool3d`, `eqx.nn.MaxPool1d`, `eqx.nn.MaxPool2d`, `eqx.nn.MaxPool3d`.
    - (Closes #59.)
    - (Thanks to @Benjamin-Walker for implementing this.)
- Added `tree_serialise_leaves` and `tree_deserialise_leaves`.
- Added `tree_inference`, as a convenience for toggling all inference flags through a model.
Refactoring for nicer APIs
- `filter_{jit,grad,value_and_grad}` now have an easier-to-use API for specifying which arguments have what behaviour.
    - Instead of having to specify `(args, kwargs)` as a single PyTree, you can now specify a `default`, `args`, and `kwargs` separately. In particular, this avoids messy constructions like `filter_spec=((...), {})` when there are no kwargs.
    - You no longer have to match up the filter specification for `args` and `kwargs` against their runtime values. Both the runtime values and the filter specification are matched against the function signature. For example, you can write `filter_jit(lambda x: x, kwargs=dict(x=True))(1)`, using a keyword argument at JIT time and a positional argument at call time.
    - Currying is available: both `filter_jit(fun)` and `filter_jit(default=...)(fun)` will work.
    - The old API is still available for backward compatibility, of course.
    - (Closes #48.)
- `tree_at` can now replace subtrees, and not just leaves.
    - (Closes #47.)
- `filter` and `partition` now support an `is_leaf` argument.
    - (Closes #68.)
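To make the signature matching in the `filter_{jit,grad,value_and_grad}` item concrete, here is a toy sketch using only the standard library. `toy_filter_jit` and its `last_spec` attribute are hypothetical: nothing is compiled, and this is not Equinox's implementation. It only shows how binding both the filter specification and the runtime arguments to the same `inspect.Signature` lets a keyword in the specification line up with a positional argument at call time, and how a `fun=None` default supports both the direct and the curried call style.

```python
import functools
import inspect


def toy_filter_jit(fun=None, *, default=False, args=(), kwargs=None):
    """Hypothetical stand-in for filter_jit that only records which spec applies to each argument."""
    if fun is None:
        # Curried form: toy_filter_jit(default=...)(fun)
        return functools.partial(toy_filter_jit, default=default, args=args, kwargs=kwargs)
    sig = inspect.signature(fun)
    # Match the filter specification against the signature once, up front.
    spec = sig.bind_partial(*args, **(kwargs or {})).arguments

    @functools.wraps(fun)
    def wrapper(*call_args, **call_kwargs):
        # Match the runtime values against the same signature, then look up each
        # argument's spec by parameter name, falling back to `default`.
        values = sig.bind(*call_args, **call_kwargs).arguments
        wrapper.last_spec = {name: spec.get(name, default) for name in values}
        return fun(*call_args, **call_kwargs)

    return wrapper


# Keyword in the spec, positional at call time.
identity = toy_filter_jit(lambda x: x, kwargs=dict(x=True))
print(identity(1), identity.last_spec)  # 1 {'x': True}

# Curried form with only a default.
add = toy_filter_jit(default=True)(lambda x, y: x + y)
print(add(1, 2), add.last_spec)  # 3 {'x': True, 'y': True}

# Positional in the spec, keywords at call time.
first = toy_filter_jit(lambda x, y: x, args=(True,))
print(first(x=1, y=2), first.last_spec)  # 1 {'x': True, 'y': False}
```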
Miscellaneous
- Calling `filter_jit(filter_grad(fun))` twice no longer leads to unnecessary recompilation: the second `filter_grad(fun)` instance is a PyTree that looks like the first `filter_grad(fun)` instance, so the existing compilation is reused.
    - This is actually an improvement over standard JAX! See https://github.com/google/jax/discussions/10284.
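The mechanism behind this can be sketched in plain Python. A JIT cache is keyed on the function it receives; `jax.grad(fun)` returns a new function object on every call, so two calls look like two different functions, whereas a wrapper that compares and hashes by its contents looks the same both times. Everything here is hypothetical (`toy_jit`, `FunWrapper`, `closure_wrapper`), a hashable dataclass stands in for a PyTree, and no compilation or differentiation actually happens.

```python
import dataclasses
import functools

compilations = []  # one entry per simulated compilation (cache miss)


@functools.lru_cache(maxsize=None)
def toy_jit(fn):
    """Hypothetical stand-in for a JIT cache keyed on the callable `fn`."""
    compilations.append(fn)
    return fn


def closure_wrapper(fun):
    # Like jax.grad: every call builds a brand-new function object,
    # which compares equal only to itself.
    def wrapped(*args):
        return fun(*args)

    return wrapped


@dataclasses.dataclass(frozen=True)
class FunWrapper:
    # Frozen dataclasses derive __eq__ and __hash__ from their fields, so two
    # wrappers around the same `fun` are interchangeable cache keys.
    fun: object

    def __call__(self, *args):
        return self.fun(*args)


def square(x):
    return x * x


toy_jit(closure_wrapper(square))
toy_jit(closure_wrapper(square))
print(len(compilations))  # 2: each closure is a new cache key

toy_jit(FunWrapper(square))
toy_jit(FunWrapper(square))
print(len(compilations))  # 3: the second FunWrapper hits the cache
```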
Full Changelog: v0.4.0...v0.5.0