Skip to content

Commit

Permalink
Added strict modules, for enforcing the abstract/final design pattern.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 12, 2023
1 parent a08fcbd commit a09da95
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 11 deletions.
6 changes: 3 additions & 3 deletions docs/pattern.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ The following is a very useful design pattern. It's not mandatory, but it comes

Finally, we should never re-override a method. Once a subclass implements a method, that's it.

This idea is very simple. Now, let's take a deep dive on why this is such a neat pattern. :)
This idea is very simple. Now, let's take a deep dive on why this is such a neat pattern, and how Equinox offers special tools to support this.

## Level 1: Abstract base classes (ABCs) as interfaces

Expand Down Expand Up @@ -116,7 +116,7 @@ class CubicInterpolation(AbstractPolynomialInterpolation):
coeffs = ... # some implementation
super().__init__(coeffs)
```
but once you have multiple classes involved, then splitting up your initialisation like this very quickly becomes far less readable. (And a reliable source of bugs.) Overall, we mandate that `__init__` methods only appear once, on our final concrete classes.
but once you have multiple classes involved, then splitting up your initialisation like this very quickly becomes far less readable. (And a reliable source of bugs.) Overall, we mandate that `__init__` methods and (non-abstract) fields may only be defined on concrete classes. Equinox supports checking this via a `strict=True` flag, passes as `class Foo(eqx.Module, strict=True)`.

## Level 3: implement methods precisely once, and concrete-means-final

Expand All @@ -127,7 +127,7 @@ In practice, we argue that's a good idea! This rule means that when you see code
def foo(interp: AbstractPolynomialInterpolation)
... = interp(...)
```
you know that it is calling `AbstractPolynomialInterpolation.__call__`, and not anything else. This is great for code readability.
you know that it is calling `AbstractPolynomialInterpolation.__call__`, and not anything else. This is great for code readability. Once again, this may be checked via a `strict=True` flag, passed as `class Foo(eqx.Module, strict=True)`.

If we assume this, then we now find ourselves arriving at a conclusion: concrete means final. That is, once we have a concrete class (every abstract method/attribute defined in our ABCs is now overriden with an implementation, so we can instantiate this class), then it is now final (we're not allowed to re-override things, so subclassing is pointless).

Expand Down
58 changes: 51 additions & 7 deletions equinox/_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import dataclasses
import functools as ft
import inspect
Expand Down Expand Up @@ -142,15 +143,40 @@ def _not_magic(k: str) -> bool:
_has_dataclass_init = weakref.WeakKeyDictionary()


_dummy_abstract = abc.abstractmethod(lambda self: 1)


# Inherits from ABCMeta as a convenience for a common use-case.
# It's not a feature we use ourselves.
class _ModuleMeta(ABCMeta): # pyright: ignore
def __new__(mcs, name, bases, dict_, **kwargs): # pyright: ignore
dict_ = {
k: _wrap_method(v) if _not_magic(k) and inspect.isfunction(v) else v
for k, v in dict_.items()
}
def __new__(
mcs, name, bases, dict_, /, strict: bool = False, **kwargs
): # pyright: ignore
if strict:
for base in bases:
if not issubclass(base, Module):
raise TypeError(
"Strict `eqx.Module`s must only inherit from subclasses of "
"`eqx.Module`."
)
cls = super().__new__(mcs, name, bases, dict_, **kwargs)
for k, v in cls.__dict__.items():
if _not_magic(k) and inspect.isfunction(v):
setattr(cls, k, _wrap_method(v))
if strict:
if not getattr(v, "__isabstractmethod__", False):
for base in bases:
old_v = getattr(base, k, _dummy_abstract)
if not inspect.isfunction(old_v):
raise TypeError(
"Strict `eqx.Module`s cannot override non-methods "
"with methods."
)
if not getattr(old_v, "__isabstractmethod__", False):
raise TypeError(
"Strict `eqx.Module`s cannot override concrete "
"methods."
)
# Do override subclasses' dataclass-__init__-s. (None of which call super, so
# they must be overriden.)
# Don't override custom __init__'s, which leads to poor ergonomics:
Expand All @@ -174,6 +200,21 @@ def __new__(mcs, name, bases, dict_, **kwargs): # pyright: ignore
cls = dataclass(eq=False, repr=False, frozen=True, init=_init)(
cls # pyright: ignore
)
if strict:
if (
len(cls.__abstractmethods__) > 0
or len(cls.__abstractvars__) > 0
or len(cls.__abstractclassvars__) > 0
):
if not _init:
raise TypeError(
"Strict `eqx.Module`s cannot have `__init__` methods."
)
if len(dataclasses.fields(cls)) > 0:
raise TypeError(
"Strict `eqx.Module`s cannot have fields. (You probably meant "
"to mark them as `eqx.AbstractVar[...]` instead.)"
)
# must happen after `dataclass(...)` we use this in `__getattribute__` to avoid
# making any property(def __wrapped__) visible until then. We want to be able to
# support property(def __wrapped__) for the sake of classes whose instances are
Expand Down Expand Up @@ -450,14 +491,17 @@ def __call__(self, x):
because `self` is just a PyTree. Unlike most other neural network libraries,
you can mix Equinox and native JAX without any difficulties at all.
!!! tip
!!! tip "For fans of strong typing."
Equinox modules are all [ABCs](https://docs.python.org/3/library/abc.html)
by default. This means you can use
[`abc.abstractmethod`](https://docs.python.org/3/library/abc.html#abc.abstractmethod).
You can also create abstract instance attributes or abstract class
attributes, see [`equinox.AbstractVar`][] and
[`equinox.AbstractClassVar`][].
[`equinox.AbstractClassVar`][]. Finally, optional Rust/Julia-like type-checking
may be enabled by passing `strict=True`, e.g.
`class Foo(eqx.Module, strict=True)`; see [this guide](../../../pattern/) for
the technical details.
""" # noqa: E501

def __hash__(self):
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ nav:
- 'api/pretty-printing.md'
- 'api/serialisation.md'
- Misc:
- 'faq.md'
- 'tricks.md'
- 'pattern.md'
- 'citation.md'
- 'faq.md'
80 changes: 80 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import dataclasses
import functools as ft
from collections.abc import Callable
Expand Down Expand Up @@ -460,3 +461,82 @@ def __check_init__(self):

with pytest.raises(dataclasses.FrozenInstanceError):
A(1)


def test_strict_noerrors():
class Abstract(eqx.Module, strict=True):
@abc.abstractmethod
def foo(self, x):
pass

class Concrete1(Abstract, strict=True):
def foo(self, x):
return x + 1

class Concrete2(Abstract):
def foo(self, x):
return x + 1


def test_strict_non_module_base():
class NotAModule:
pass

with pytest.raises(TypeError, match="subclasses of `eqx.Module`"):

class MyModule(eqx.Module, NotAModule, strict=True):
pass


def test_strict_method_reoverride():
class A(eqx.Module, strict=True):
@abc.abstractmethod
def foo(self, x):
pass

class B(A, strict=True):
def foo(self, x):
pass

with pytest.raises(TypeError, match="concrete methods"):

class C(B, strict=True):
def foo(self, x):
pass


def test_strict_init():
with pytest.raises(TypeError, match="__init__"):

class Abstract(eqx.Module, strict=True):
def __init__(self):
pass

@abc.abstractmethod
def foo(self):
pass


def test_strict_fields():
class Abstract1(eqx.Module, strict=True):
bar: eqx.AbstractVar[int]

@abc.abstractmethod
def foo(self):
pass

class Abstract2(eqx.Module, strict=True):
bar: eqx.AbstractClassVar[int]

@abc.abstractmethod
def foo(self):
pass

with pytest.raises(TypeError, match="fields"):

class Abstract3(eqx.Module, strict=True):
bar: int

@abc.abstractmethod
def foo(self):
pass

0 comments on commit a09da95

Please sign in to comment.