-
-
Notifications
You must be signed in to change notification settings - Fork 142
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Improved filtered transformations (#71) * Updated filter_{jit,grad,value_and_grad} * Documented that BatchNorm supports multiple axis_names * Doc fix * added tree_inference * Added filter_vmap and filter_pmap * Made filter_{vmap,pmp} non-experimental. Fixed doc issues. * Added 'inference' call-time argument to 'MultiheadAttention', which was missed when switching over 'Dropout' to use 'inference'. * Added test for tree_inference * Doc improvements * Doc tweaks * Updated tests * Two minor bugfixes for filter_{jit,grad,vmap,pmap,value_and_grad}. 1. Applying filter_{jit,vmap,pmap} to a function with *args and **kwargs should now work. Previously inspect.Signature.apply_defaults was not filling these in. 2. The output of all filtered transformations has become a Module. This means that e.g. calling filter_jit(filter_grad(f)) multiple times will not induce any recompilation, as all filter_grad(f)s are PyTrees of the same structure as each other. * Crash and test fixes * tidy up * Fixes and test fixes * Finished adding tests. * added is_leaf argument to filter and partition (#72) * Improved tree_at a lot: (#73) - Can now substitute arbitrary nodes, not just leaves - These substituted nodes can now have different structures to each other. - Will raise an error if `where` depends on the values of the leaves of the PyTree. In addition, `tree_equal` should now work under JIT. * Lots of doc fixes and tweaks (#74) * Added tree_{de,}serialise_leaves (#80) * Added max/average pooling. (#77) * Tidy up pooling implementation (#81) * Version bump * Fixed tree_at breaking when replacing Nones. (#83) Co-authored-by: Ben Walker <[email protected]>
- Loading branch information
1 parent
9325144
commit 291c4d7
Showing
40 changed files
with
2,970 additions
and
384 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Pooling | ||
|
||
::: equinox.nn.Pool | ||
selection: | ||
members: | ||
- __init__ | ||
- __call__ | ||
|
||
--- | ||
|
||
::: equinox.nn.AvgPool1D | ||
selection: | ||
members: false | ||
|
||
--- | ||
|
||
::: equinox.nn.AvgPool2D | ||
selection: | ||
members: false | ||
|
||
--- | ||
|
||
::: equinox.nn.AvgPool3D | ||
selection: | ||
members: false | ||
|
||
--- | ||
|
||
::: equinox.nn.MaxPool1D | ||
selection: | ||
members: false | ||
|
||
--- | ||
|
||
::: equinox.nn.MaxPool2D | ||
selection: | ||
members: false | ||
|
||
--- | ||
|
||
::: equinox.nn.MaxPool3D | ||
selection: | ||
members: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Manipulating PyTrees | ||
|
||
::: equinox.apply_updates | ||
|
||
--- | ||
|
||
::: equinox.tree_at | ||
|
||
--- | ||
|
||
::: equinox.tree_inference |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Serialisation | ||
|
||
::: equinox.tree_serialise_leaves | ||
|
||
--- | ||
|
||
::: equinox.tree_deserialise_leaves |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import functools as ft | ||
from typing import Any | ||
|
||
import jax | ||
|
||
from .filters import combine, partition | ||
from .module import Module, static_field | ||
|
||
|
||
def hashable_partition(pytree, filter_spec): | ||
dynamic, static = partition(pytree, filter_spec) | ||
static_leaves, static_treedef = jax.tree_flatten(static) | ||
static_leaves = tuple(static_leaves) | ||
return dynamic, static_leaves, static_treedef | ||
|
||
|
||
def hashable_combine(dynamic, static_leaves, static_treedef): | ||
static = jax.tree_unflatten(static_treedef, static_leaves) | ||
return combine(dynamic, static) | ||
|
||
|
||
class Static(Module): | ||
value: Any = static_field() | ||
|
||
|
||
def strip_wrapped_partial(fun): | ||
if hasattr(fun, "__wrapped__"): # ft.wraps | ||
return strip_wrapped_partial(fun.__wrapped__) | ||
if isinstance(fun, ft.partial): | ||
return strip_wrapped_partial(fun.func) | ||
return fun | ||
|
||
|
||
def compile_cache(fun): | ||
@ft.lru_cache(maxsize=None) | ||
def _cache(leaves, treedef): | ||
args, kwargs = jax.tree_unflatten(treedef, leaves) | ||
return fun(*args, **kwargs) | ||
|
||
@ft.wraps(fun) | ||
def _fun(*args, **kwargs): | ||
leaves, treedef = jax.tree_flatten((args, kwargs)) | ||
leaves = tuple(leaves) | ||
return _cache(leaves, treedef) | ||
|
||
return _fun |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import typing | ||
from types import FunctionType | ||
from typing import Any | ||
|
||
|
||
# Inherits from type so that _WithRepr instances are types and can be used as | ||
# e.g. Sequence[_WithRepr(...)] | ||
class _WithRepr(type): | ||
def __new__(self, string): | ||
out = super().__new__(self, string, (), {}) | ||
# prevent the custom typing repr from doing the wrong thing | ||
out.__module__ = "builtins" | ||
return out | ||
|
||
def __init__(self, string): | ||
self.string = string | ||
|
||
def __repr__(self): | ||
return self.string | ||
|
||
|
||
def doc_repr(obj: Any, string: str): | ||
if getattr(typing, "GENERATING_DOCUMENTATION", False): | ||
return _WithRepr(string) | ||
else: | ||
return obj | ||
|
||
|
||
def doc_strip_annotations(fn: FunctionType) -> FunctionType: | ||
if getattr(typing, "GENERATING_DOCUMENTATION", False): | ||
fn.__annotations__ = None | ||
return fn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.