Skip to content

Commit

Permalink
Merge pull request #38 from patrick-kidger/attn-convt-layernorm
Browse files Browse the repository at this point in the history
Attention, Transposed Convolutions, Embeddings, LayerNorm
  • Loading branch information
patrick-kidger authored Mar 15, 2022
2 parents 4c07080 + 97ed6e0 commit 3343e84
Show file tree
Hide file tree
Showing 17 changed files with 1,104 additions and 52 deletions.
16 changes: 16 additions & 0 deletions docs/_static/mathjax.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// MathJax configuration used by the generated documentation.
window.MathJax = {
tex: {
// Delimiters for inline and display math — presumably the forms emitted
// by the docs' math extension (c.f. the "arithmatex" class below).
inlineMath: [["\\(", "\\)"]],
displayMath: [["\\[", "\\]"]],
// Allow \$ escapes and LaTeX environments inside math.
processEscapes: true,
processEnvironments: true
},
options: {
// Skip every element except those explicitly carrying the
// "arithmatex" class.
ignoreHtmlClass: ".*|",
processHtmlClass: "arithmatex"
}
};

// Re-run typesetting on page updates.
// NOTE(review): `document$` looks like mkdocs-material's page observable
// (fires on each navigation) — confirm against the docs theme in use.
document$.subscribe(() => {
MathJax.typesetPromise()
})
7 changes: 7 additions & 0 deletions docs/api/nn/attention.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Attention

::: equinox.nn.MultiheadAttention
    selection:
        members:
            - __init__
            - __call__
29 changes: 28 additions & 1 deletion docs/api/nn/conv.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Convolutional layers
# Convolutional

::: equinox.nn.Conv
selection:
members:
Expand All @@ -7,6 +8,14 @@

---

::: equinox.nn.ConvTranspose
selection:
members:
- __init__
- __call__

---

::: equinox.nn.Conv1d
selection:
members: false
Expand All @@ -22,3 +31,21 @@
::: equinox.nn.Conv3d
selection:
members: false

---

::: equinox.nn.ConvTranspose1d
selection:
members: false

---

::: equinox.nn.ConvTranspose2d
selection:
members: false

---

::: equinox.nn.ConvTranspose3d
selection:
members: false
7 changes: 7 additions & 0 deletions docs/api/nn/embedding.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Embeddings

::: equinox.nn.Embedding
    selection:
        members:
            - __init__
            - __call__
2 changes: 1 addition & 1 deletion docs/api/nn/linear.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Linear layers
# Linear

::: equinox.nn.Linear
selection:
Expand Down
7 changes: 7 additions & 0 deletions docs/api/nn/normalisation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Normalisation

::: equinox.nn.LayerNorm
    selection:
        members:
            - __init__
            - __call__
2 changes: 1 addition & 1 deletion docs/api/nn/rnn.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Recurrent layers
# Recurrent

::: equinox.nn.GRUCell
selection:
Expand Down
2 changes: 1 addition & 1 deletion equinox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
from .update import apply_updates


__version__ = "0.2.1"
__version__ = "0.2.2"
105 changes: 99 additions & 6 deletions equinox/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,108 @@
import inspect
import typing
from typing import Any
from typing import Generic, Tuple, TypeVar, Union

import jax
import jax.numpy as jnp


# Custom flag we set when generating documentation.
# We do a lot of custom hackery in here to produce nice-looking docs.
if getattr(typing, "GENERATING_DOCUMENTATION", False):
Array = "jax.numpy.ndarray"
PyTree = "PyTree"

def _item_to_str(item: Union[str, type, slice]) -> str:
if isinstance(item, slice):
if item.step is not None:
raise NotImplementedError
return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
elif item is ...:
return "..."
elif inspect.isclass(item):
return item.__name__
else:
return repr(item)

def _maybe_tuple_to_str(
item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
) -> str:
if isinstance(item, tuple):
if len(item) == 0:
# Explicit brackets
return "()"
else:
# No brackets
return ", ".join([_item_to_str(i) for i in item])
else:
return _item_to_str(item)

#
# First we have Generic versions of Array and PyTree.
#
# Crucially the __module__ and __qualname__ are overridden. This particular combo
# makes Python's typing module just use the __qualname__ as what is displayed in
# stringified type annotations.
# (For some strange reason typing uses a custom stringifiation algorithm, rather
# than just str(...) or repr(...).)
#
# c.f.
# https://github.com/python/cpython/blob/634984d7dbdd91e0a51a793eed4d870e139ae1e0/Lib/typing.py#L203 # noqa: E501
#
# Note that in general overriding __module__ can be a bit dangerous, and will break
# functionality in the inspect standard library.
#

_Annotation = TypeVar("_Annotation")

class _Array(Generic[_Annotation]):
pass

class _PyTree(Generic[_Annotation]):
pass

_Array.__module__ = "builtins"
_Array.__qualname__ = "Array"
_PyTree.__module__ = "builtins"
_PyTree.__qualname__ = "PyTree"

#
# Now we have Array and PyTree themselves. In order to get the desired behaviour in
# docs, we now pass in a type variable with the right __qualname__ (and __module__
# set to "builtins" as usual) that will render in the desired way.
#

class Array:
    """Documentation stand-in: ``Array["dim"]`` renders as a shape-annotated
    array type in stringified annotations."""

    def __class_getitem__(cls, item):
        # Throwaway class whose __qualname__ carries the rendered annotation
        # text; the `builtins` module trick makes typing display exactly
        # that text with no module prefix.
        annotation = type("X", (), {})
        annotation.__module__ = "builtins"
        annotation.__qualname__ = _maybe_tuple_to_str(item)
        return _Array[annotation]

class PyTree:
    """Documentation stand-in: ``PyTree[...]`` renders with its structure
    annotation in stringified annotations."""

    def __class_getitem__(cls, item):
        annotation = type("X", (), {})
        annotation.__module__ = "builtins"
        annotation.__qualname__ = _maybe_tuple_to_str(item)
        return _PyTree[annotation]

# Same __module__ trick on the public classes themselves, so that a bare
# `def f(x: Array)` displays correctly as well as `def f(x: Array["dim"])`.
# __qualname__ is already correct and needs no override.
Array.__module__ = "builtins"
PyTree.__module__ = "builtins"

else:
Array = jnp.ndarray
PyTree = Any

class Array:
    """Runtime version: subscripting is a no-op, so ``Array[...]`` is just
    ``Array`` (shape annotations are documentation only)."""

    def __class_getitem__(cls, item):
        del item  # the annotation is ignored at runtime
        return Array

class PyTree:
    """Runtime version: subscripting is a no-op, so ``PyTree[...]`` is just
    ``PyTree`` (structure annotations are documentation only)."""

    def __class_getitem__(cls, item):
        del item  # the annotation is ignored at runtime
        return PyTree


TreeDef = type(jax.tree_structure(0))
14 changes: 13 additions & 1 deletion equinox/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from .attention import MultiheadAttention
from .composed import MLP, Sequential
from .conv import Conv, Conv1d, Conv2d, Conv3d
from .conv import (
Conv,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from .dropout import Dropout
from .embedding import Embedding
from .linear import Identity, Linear
from .normalisation import LayerNorm
from .rnn import GRUCell, LSTMCell
Loading

0 comments on commit 3343e84

Please sign in to comment.