From 0c7849fe84792feff173bd04bd6f5acc6d7be501 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 7 Sep 2023 07:53:23 -0700 Subject: [PATCH] Now running tree_check when initialising a Module --- equinox/_module.py | 41 ++++++++++++++++++++++++++++++++++++++++- equinox/_tree.py | 19 ++++++++++++++----- tests/test_module.py | 23 +++++++++++++++++++++++ tests/test_tree.py | 4 ++-- 4 files changed, 79 insertions(+), 8 deletions(-) diff --git a/equinox/_module.py b/equinox/_module.py index ed04ec91..3bad6cf8 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -14,7 +14,7 @@ from ._caches import internal_lru_caches from ._doc_utils import doc_repr from ._pretty_print import tree_pformat -from ._tree import tree_equal +from ._tree import tree_check_internal, tree_equal _P = ParamSpec("_P") @@ -110,6 +110,11 @@ def _not_magic(k: str) -> bool: _has_dataclass_init = weakref.WeakKeyDictionary() +_has_been_checked = weakref.WeakValueDictionary() + + +def _skip(node): + return isinstance(node, Module) and node is _has_been_checked.get(id(node), None) # Inherits from ABCMeta as a convenience for a common use-case. @@ -201,6 +206,40 @@ def __call__(cls, *args, **kwargs): else: setattr(self, field.name, converter(getattr(self, field.name))) object.__setattr__(self, "__class__", cls) + # Note that this only runs during the initial creation, and not during + # unflattening. + try: + tree_check_internal(self, _skip) + except ValueError as e: + raise ValueError( + "As of Equinox v0.11.0, `equinox.Module`s now validate that there " + "aren't any repeated layers inside a module. This is because this was " + "previously a common bug.\n" + "As an example, something like this:\n" + "```\n`" + "class MyModule(eqx.Module):\n" + " linear1: eqx.nn.Linear\n" + " linear2: eqx.nn.Linear\n" + "\n" + " def __init__(self, ...):\n" + " linear = eqx.nn.Linear(...)\n" + " self.linear1 = linear\n" + " self.linear2 = linear\n" + "```\n" + "resulted in two independent linear layers after a gradient update had " + "happened.\n" + "An exception is being thrown now as this error been detected.\n" + "If you intended to share the layer, then use the new functionality " + "`eqx.nn.Shared`. If you intended to have two duplicate copies, then " + "please instantiate two separate layers. If it's easier, you can also " + "clone an existing layer by doing\n" + "```\n" + "layer = ...\n" + "leaves, treedef = jax.tree_util.tree_flatten(layer)\n" + "clone_layer = jax.tree_util.tree_unflatten(treedef, leaves)\n" + "```" + ) from e + _has_been_checked[id(self)] = self return self diff --git a/equinox/_tree.py b/equinox/_tree.py index de1b1dc3..ef25e357 100644 --- a/equinox/_tree.py +++ b/equinox/_tree.py @@ -338,6 +338,14 @@ def is_leaf(node): return jtu.tree_flatten(pytree, is_leaf=is_leaf) +def tree_check_internal(pytree, skip) -> None: + """As `tree_check`, but can skips checking some nodes (typically those that have + alread been checked). + """ + all_nodes = {} + _tree_check(pytree, all_nodes, skip) + + def tree_check(pytree: Any) -> None: """Checks if the PyTree is well-formed: does it have no self-references, and does it have no duplicate layers. @@ -389,13 +397,13 @@ def tree_check(pytree: Any) -> None: A `ValueError` if the PyTree is not well-formed. """ all_nodes = {} - _tree_check(pytree, all_nodes) + _tree_check(pytree, all_nodes, skip=lambda _: False) _leaf_treedef = jtu.tree_structure(0) -def _tree_check(node, all_nodes): +def _tree_check(node, all_nodes, skip): subnodes, treedef = tree_flatten_one_level(node) # We allow duplicate leaves and empty containers, so don't raise an error with those if treedef != _leaf_treedef and treedef.num_leaves > 0: @@ -422,7 +430,8 @@ def _tree_check(node, all_nodes): except AttributeError: # AttributeError: in case we cannot get __name__ for some weird reason. type_string = "" - all_nodes[id(node)] = (True, type_string) - for subnode in subnodes: - _tree_check(subnode, all_nodes) + if not skip(node): + all_nodes[id(node)] = (True, type_string) + for subnode in subnodes: + _tree_check(subnode, all_nodes, skip) all_nodes[id(node)] = (False, type_string) diff --git a/tests/test_module.py b/tests/test_module.py index 0f426d31..f3ba8b88 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,4 +1,5 @@ import functools as ft +import gc from typing import Any import jax @@ -252,3 +253,25 @@ def test_wrapped(getkey): y = eqx.filter_vmap(eqx.nn.Linear(2, 2, key=getkey())) x, y = eqx.filter((x, y), eqx.is_array) jtu.tree_map(lambda x, y: x + y, x, y) + + +def test_tree_check_cache(getkey): + gc.collect() + has_been_checked = eqx._module._has_been_checked + num_checked = len(has_been_checked) + mlp = eqx.nn.MLP(2, 2, 2, 2, key=getkey()) + # +4: one for `MLP`, and three for its `Linear` layers inside. + assert len(has_been_checked) == num_checked + 4 + del mlp + gc.collect() + assert len(has_been_checked) == num_checked + + +def test_duplicate_layer_error(getkey): + class M(eqx.Module): + l1: eqx.nn.Linear + l2: eqx.nn.Linear + + linear = eqx.nn.Linear(2, 2, key=getkey()) + with pytest.raises(ValueError, match="As of Equinox v0.11.0"): + M(linear, linear) diff --git a/tests/test_tree.py b/tests/test_tree.py index 9f883704..0078d413 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -231,9 +231,9 @@ def __init__(self, test=lambda x: 2 * x) -> None: def _transform(self, x): return self.test(x) - a = SubComponent() with pytest.raises(ValueError): - eqx.tree_check(a) + # Checking occurs when initialising the module. + SubComponent() def test_tree_check_none():