-
-
Notifications
You must be signed in to change notification settings - Fork 142
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from patrick-kidger/attn-convt-layernorm
Attention, Transposed Convolutions, Embeddings, LayerNorm
- Loading branch information
Showing
17 changed files
with
1,104 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
window.MathJax = { | ||
tex: { | ||
inlineMath: [["\\(", "\\)"]], | ||
displayMath: [["\\[", "\\]"]], | ||
processEscapes: true, | ||
processEnvironments: true | ||
}, | ||
options: { | ||
ignoreHtmlClass: ".*|", | ||
processHtmlClass: "arithmatex" | ||
} | ||
}; | ||
|
||
document$.subscribe(() => { | ||
MathJax.typesetPromise() | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Attention | ||
|
||
::: equinox.nn.MultiheadAttention | ||
selection: | ||
members: | ||
- __init__ | ||
- __call__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Embeddings | ||
|
||
::: equinox.nn.Embedding | ||
selection: | ||
members: | ||
- __init__ | ||
- __call__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Linear layers | ||
# Linear | ||
|
||
::: equinox.nn.Linear | ||
selection: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Normalisation | ||
|
||
::: equinox.nn.LayerNorm | ||
selection: | ||
members: | ||
- __init__ | ||
- __call__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Recurrent layers | ||
# Recurrent | ||
|
||
::: equinox.nn.GRUCell | ||
selection: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,4 +23,4 @@ | |
from .update import apply_updates | ||
|
||
|
||
__version__ = "0.2.1" | ||
__version__ = "0.2.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,108 @@ | ||
import inspect | ||
import typing | ||
from typing import Any | ||
from typing import Generic, Tuple, TypeVar, Union | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
|
||
# Custom flag we set when generating documentation. | ||
# We do a lot of custom hackery in here to produce nice-looking docs. | ||
if getattr(typing, "GENERATING_DOCUMENTATION", False): | ||
Array = "jax.numpy.ndarray" | ||
PyTree = "PyTree" | ||
|
||
def _item_to_str(item: Union[str, type, slice]) -> str: | ||
if isinstance(item, slice): | ||
if item.step is not None: | ||
raise NotImplementedError | ||
return _item_to_str(item.start) + ": " + _item_to_str(item.stop) | ||
elif item is ...: | ||
return "..." | ||
elif inspect.isclass(item): | ||
return item.__name__ | ||
else: | ||
return repr(item) | ||
|
||
def _maybe_tuple_to_str( | ||
item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]] | ||
) -> str: | ||
if isinstance(item, tuple): | ||
if len(item) == 0: | ||
# Explicit brackets | ||
return "()" | ||
else: | ||
# No brackets | ||
return ", ".join([_item_to_str(i) for i in item]) | ||
else: | ||
return _item_to_str(item) | ||
|
||
# | ||
# First we have Generic versions of Array and PyTree. | ||
# | ||
# Crucially the __module__ and __qualname__ are overridden. This particular combo | ||
# makes Python's typing module just use the __qualname__ as what is displayed in | ||
# stringified type annotations. | ||
# (For some strange reason typing uses a custom stringifiation algorithm, rather | ||
# than just str(...) or repr(...).) | ||
# | ||
# c.f. | ||
# https://github.com/python/cpython/blob/634984d7dbdd91e0a51a793eed4d870e139ae1e0/Lib/typing.py#L203 # noqa: E501 | ||
# | ||
# Note that in general overriding __module__ can be a bit dangerous, and will break | ||
# functionality in the inspect standard library. | ||
# | ||
|
||
_Annotation = TypeVar("_Annotation") | ||
|
||
class _Array(Generic[_Annotation]): | ||
pass | ||
|
||
class _PyTree(Generic[_Annotation]): | ||
pass | ||
|
||
_Array.__module__ = "builtins" | ||
_Array.__qualname__ = "Array" | ||
_PyTree.__module__ = "builtins" | ||
_PyTree.__qualname__ = "PyTree" | ||
|
||
# | ||
# Now we have Array and PyTree themselves. In order to get the desired behaviour in | ||
# docs, we now pass in a type variable with the right __qualname__ (and __module__ | ||
# set to "builtins" as usual) that will render in the desired way. | ||
# | ||
|
||
class Array: | ||
def __class_getitem__(cls, item): | ||
class X: | ||
pass | ||
|
||
X.__module__ = "builtins" | ||
X.__qualname__ = _maybe_tuple_to_str(item) | ||
return _Array[X] | ||
|
||
class PyTree: | ||
def __class_getitem__(cls, item): | ||
class X: | ||
pass | ||
|
||
X.__module__ = "builtins" | ||
X.__qualname__ = _maybe_tuple_to_str(item) | ||
return _PyTree[X] | ||
|
||
# Same __module__ trick here again. (So that we get the correct display when | ||
# doing `def f(x: Array)` as well as `def f(x: Array["dim"])`. | ||
# | ||
# Don't need to set __qualname__ as that's already correct. | ||
Array.__module__ = "builtins" | ||
PyTree.__module__ = "builtins" | ||
|
||
else: | ||
Array = jnp.ndarray | ||
PyTree = Any | ||
|
||
class Array: | ||
def __class_getitem__(cls, item): | ||
return Array | ||
|
||
class PyTree: | ||
def __class_getitem__(cls, item): | ||
return PyTree | ||
|
||
|
||
TreeDef = type(jax.tree_structure(0)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,17 @@ | ||
from .attention import MultiheadAttention | ||
from .composed import MLP, Sequential | ||
from .conv import Conv, Conv1d, Conv2d, Conv3d | ||
from .conv import ( | ||
Conv, | ||
Conv1d, | ||
Conv2d, | ||
Conv3d, | ||
ConvTranspose, | ||
ConvTranspose1d, | ||
ConvTranspose2d, | ||
ConvTranspose3d, | ||
) | ||
from .dropout import Dropout | ||
from .embedding import Embedding | ||
from .linear import Identity, Linear | ||
from .normalisation import LayerNorm | ||
from .rnn import GRUCell, LSTMCell |
Oops, something went wrong.