
Equinox v0.5.0

Released by @github-actions on 06 May at 21:24 · commit 291c4d7

This is a big update.

Exciting new features!

  • Added filter_vmap.

    • This can be used to create ensembles of models.
    • (Closes #65.)
  • Added filter_pmap.

    • (Closes #65.)
  • Added pooling layers:

    • eqx.nn.Pool
    • eqx.nn.AvgPool1d
    • eqx.nn.AvgPool2d
    • eqx.nn.AvgPool3d
    • eqx.nn.MaxPool1d
    • eqx.nn.MaxPool2d
    • eqx.nn.MaxPool3d
    • (Closes #59.)
    • (Thanks to @Benjamin-Walker for implementing this.)
  • Added tree_serialise_leaves and tree_deserialise_leaves.

    • This can be used to save models to a file and load them back.
    • (Closes #46.)
    • (Thanks to @jaschau for helpful discussions on this.)
  • Added tree_inference, a convenience function for toggling all inference flags throughout a model.

Refactoring for nicer APIs

  • filter_{jit,grad,value_and_grad} now have an easier-to-use API for specifying the behaviour of each argument.

    • Instead of having to specify (args, kwargs) as a single PyTree, you can now specify default, args and kwargs separately. In particular, this avoids messy constructions like filter_spec=((...), {}) when there are no kwargs.
    • You no longer have to match up the filter specification for args and kwargs against their runtime values. Instead, both the runtime values and the filter specification are matched against the function signature.
      For example, filter_jit(lambda x: x, kwargs=dict(x=True))(1) works: the argument is specified by keyword when wrapping with filter_jit, but passed positionally at call time.
    • Currying is available: both filter_jit(fun) and filter_jit(default=...)(fun) will work.
    • The old API is still available for backward compatibility, of course.
    • (Closes #48.)
  • tree_at can now replace subtrees, and not just leaves.

    • (Closes #47.)
  • filter and partition now support an is_leaf argument.

    • (Closes #68.)

Miscellaneous

  • Calling filter_jit(filter_grad(fun)) twice no longer leads to unnecessary recompilation: filter_grad(fun) now returns a PyTree, so the second filter_grad(fun) instance looks the same as the first, and filter_jit reuses the existing compilation.
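A toy model in plain Python of why returning a PyTree avoids recompilation: a cache keyed on the wrapped function misses on every fresh closure (closures compare by identity), but hits when the wrapper is a value that compares equal by its contents. `ToyJitCache`, `closure_wrapper` and `ValueWrapper` are hypothetical names, not Equinox internals, and the wrappers simply call `fun` rather than differentiating it.

```python
import dataclasses
from typing import Callable


class ToyJitCache:
    """Hypothetical stand-in for a JIT cache: "compiles" once per distinct function.

    Lookup uses __hash__/__eq__, loosely like a cache keyed on static inputs."""

    def __init__(self):
        self.compiled = {}
        self.compile_count = 0

    def __call__(self, fn, *args):
        if fn not in self.compiled:
            self.compile_count += 1
            self.compiled[fn] = fn  # a real JIT would store compiled code here
        return self.compiled[fn](*args)


def closure_wrapper(fun: Callable) -> Callable:
    # Each call returns a brand-new closure, which hashes and compares by identity.
    def wrapped(*args):
        return fun(*args)

    return wrapped


@dataclasses.dataclass(frozen=True)
class ValueWrapper:
    # A frozen dataclass hashes and compares by its fields, so two instances
    # wrapping the same `fun` are equal (analogous to identical-looking PyTrees).
    fun: Callable

    def __call__(self, *args):
        return self.fun(*args)


def square(x):
    return x * x


closure_cache = ToyJitCache()
closure_cache(closure_wrapper(square), 3.0)
closure_cache(closure_wrapper(square), 3.0)  # new closure: cache miss, "recompiles"

value_cache = ToyJitCache()
value_cache(ValueWrapper(square), 3.0)
value_cache(ValueWrapper(square), 3.0)  # equal wrapper: cache hit
```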

Full Changelog: v0.4.0...v0.5.0