Skip to content

Commit

Permalink
Now running tree_check when initialising a Module
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 7, 2023
1 parent 339611d commit de5a982
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 8 deletions.
41 changes: 40 additions & 1 deletion equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ._caches import internal_lru_caches
from ._doc_utils import doc_repr
from ._pretty_print import tree_pformat
from ._tree import tree_equal
from ._tree import tree_check_internal, tree_equal


_P = ParamSpec("_P")
Expand Down Expand Up @@ -110,6 +110,11 @@ def _not_magic(k: str) -> bool:


_has_dataclass_init = weakref.WeakKeyDictionary()
_has_been_checked = weakref.WeakValueDictionary()


def _skip(node):
return isinstance(node, Module) and node is _has_been_checked.get(id(node), None)


# Inherits from ABCMeta as a convenience for a common use-case.
Expand Down Expand Up @@ -201,6 +206,40 @@ def __call__(cls, *args, **kwargs):
else:
setattr(self, field.name, converter(getattr(self, field.name)))
object.__setattr__(self, "__class__", cls)
# Note that this only runs during the initial creation, and not during
# unflattening.
try:
tree_check_internal(self, _skip)
except ValueError as e:
raise ValueError(
"As of Equinox v0.11.0, `equinox.Module`s now validate that there "
"aren't any repeated layers inside a module. This is because this was "
"previously a common bug.\n"
"As an example, something like this:\n"
"```\n`"
"class MyModule(eqx.Module):\n"
" linear1: eqx.nn.Linear\n"
" linear2: eqx.nn.Linear\n"
"\n"
" def __init__(self, ...):\n"
" linear = eqx.nn.Linear(...)\n"
" self.linear1 = linear\n"
" self.linear2 = linear\n"
"```\n"
"resulted in two independent linear layers after a gradient update had "
"happened.\n"
"An exception is being thrown now as this error been detected.\n"
"If you intended to share the layer, then use the new functionality "
"`eqx.nn.Shared`. If you intended to have two duplicate copies, then "
"please instantiate two separate layers. If it's easier, you can also "
"clone an existing layer by doing\n"
"```\n"
"layer = ...\n"
"leaves, treedef = jax.tree_util.tree_flatten(layer)\n"
"clone_layer = jax.tree_util.tree_unflatten(treedef, leaves)\n"
"```"
) from e
_has_been_checked[id(self)] = self
return self


Expand Down
19 changes: 14 additions & 5 deletions equinox/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,14 @@ def is_leaf(node):
return jtu.tree_flatten(pytree, is_leaf=is_leaf)


def tree_check_internal(pytree, skip) -> None:
"""As `tree_check`, but can skips checking some nodes (typically those that have
alread been checked).
"""
all_nodes = {}
_tree_check(pytree, all_nodes, skip)


def tree_check(pytree: Any) -> None:
"""Checks if the PyTree is well-formed: does it have no self-references, and does
it have no duplicate layers.
Expand Down Expand Up @@ -389,13 +397,13 @@ def tree_check(pytree: Any) -> None:
A `ValueError` if the PyTree is not well-formed.
"""
all_nodes = {}
_tree_check(pytree, all_nodes)
_tree_check(pytree, all_nodes, skip=lambda _: False)


_leaf_treedef = jtu.tree_structure(0)


def _tree_check(node, all_nodes):
def _tree_check(node, all_nodes, skip):
subnodes, treedef = tree_flatten_one_level(node)
# We allow duplicate leaves and empty containers, so don't raise an error with those
if treedef != _leaf_treedef and treedef.num_leaves > 0:
Expand All @@ -422,7 +430,8 @@ def _tree_check(node, all_nodes):
except AttributeError:
# AttributeError: in case we cannot get __name__ for some weird reason.
type_string = "<unknown type>"
all_nodes[id(node)] = (True, type_string)
for subnode in subnodes:
_tree_check(subnode, all_nodes)
if not skip(node):
all_nodes[id(node)] = (True, type_string)
for subnode in subnodes:
_tree_check(subnode, all_nodes, skip)
all_nodes[id(node)] = (False, type_string)
23 changes: 23 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools as ft
import gc
from typing import Any

import jax
Expand Down Expand Up @@ -252,3 +253,25 @@ def test_wrapped(getkey):
y = eqx.filter_vmap(eqx.nn.Linear(2, 2, key=getkey()))
x, y = eqx.filter((x, y), eqx.is_array)
jtu.tree_map(lambda x, y: x + y, x, y)


def test_tree_check_cache(getkey):
gc.collect()
has_been_checked = eqx._module._has_been_checked
num_checked = len(has_been_checked)
mlp = eqx.nn.MLP(2, 2, 2, 2, key=getkey())
# +4: one for `MLP`, and three for its `Linear` layers inside.
assert len(has_been_checked) == num_checked + 4
del mlp
gc.collect()
assert len(has_been_checked) == num_checked


def test_duplicate_layer_error(getkey):
class M(eqx.Module):
l1: eqx.nn.Linear
l2: eqx.nn.Linear

linear = eqx.nn.Linear(2, 2, key=getkey())
with pytest.raises(ValueError, match="As of Equinox v0.11.0"):
M(linear, linear)
4 changes: 2 additions & 2 deletions tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ def __init__(self, test=lambda x: 2 * x) -> None:
def _transform(self, x):
return self.test(x)

a = SubComponent()
with pytest.raises(ValueError):
eqx.tree_check(a)
# Checking occurs when initialising the module.
SubComponent()


def test_tree_check_none():
Expand Down

0 comments on commit de5a982

Please sign in to comment.