Added support for sharing layers between different parts of a model.
Commit b8e3165 (1 parent: c42bff5)
Showing 5 changed files with 292 additions and 0 deletions.
@@ -0,0 +1,7 @@
# Sharing layers

::: equinox.nn.Shared
    selection:
        members:
            - __init__
            - __call__
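The documentation page above only pulls in the generated API reference. As a quick orientation, here is a minimal sketch of the pattern it documents -- tying one array to two positions in a small PyTree. The variable names are illustrative and not part of this commit:

```python
import equinox as eqx
import jax.numpy as jnp

a = jnp.zeros((2, 3))
b = jnp.ones((2, 3))  # will be discarded and replaced by a reference to `a`

where = lambda pair: pair[1]  # node to overwrite
get = lambda pair: pair[0]    # value to put there
shared = eqx.nn.Shared((a, b), where, get)

first, second = shared()  # expand back out before use
assert first is second    # both positions now reference the same array
```

Only the first copy of the array is actually stored; the second position is re-populated with a reference to it each time the `Shared` object is called.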
@@ -0,0 +1,158 @@
from collections.abc import Callable

from jaxtyping import PyTree

from .._eval_shape import filter_eval_shape
from .._module import Module
from .._tree import tree_at, tree_equal


class SharedNode:
    """Placeholder value for nodes that have been removed by `eqx.nn.Shared`."""

    def __repr__(self):
        return "SharedNode"


class Shared(Module):
    """Used to tie together multiple nodes across a PyTree.

    Note that Equinox modules are Py**Trees** -- so the same layer, appearing in two
    different parts of the tree, will be treated as two copies of this layer. For
    example,
    ```python
    class SubModel(eqx.Module):
        linear: eqx.nn.Linear

    class Model(eqx.Module):
        linear: eqx.nn.Linear
        submodel: SubModel

        def __init__(self):
            linear = eqx.nn.Linear(...)
            self.linear = linear
            self.submodel = SubModel(linear)
    ```
    declares `model.linear` and `model.submodel.linear` as two separate layers. They
    will start with the same initial parameter values, and then update independently
    during training.

    For when we really do want to share layers or weights across different parts of a
    model, `eqx.nn.Shared` exists as a way to easily express this in the PyTree
    paradigm.

    !!! Example

        It is common in many language models to have an initial embedding matrix at
        the start, and then to reuse this as the weight of the final linear
        transformation.
        ```python
        import equinox as eqx
        import jax
        import jax.numpy as jnp
        from jaxtyping import Array, Int

        class LanguageModel(eqx.Module):
            shared: eqx.nn.Shared

            def __init__(self):
                embedding = eqx.nn.Embedding(...)
                linear = eqx.nn.Linear(...)
                # These two weights will now be tied together.
                where = lambda embed_and_lin: embed_and_lin[1].weight
                get = lambda embed_and_lin: embed_and_lin[0].weight
                self.shared = eqx.nn.Shared((embedding, linear), where, get)

            def __call__(self, tokens: Int[Array, "sequence"]):
                # Expand back out so we can evaluate these layers.
                embedding, linear = self.shared()
                assert embedding.weight is linear.weight  # same parameter!
                # Now go ahead and evaluate your language model.
                values = jax.vmap(embedding)(tokens)
                ...  # other layers, probably
                return jax.vmap(linear)(values)
        ```

        _(Side note: you will sometimes see some authors referring to transposing
        the embedding matrix prior to the final linear layer. This is because some
        other libraries store the weight matrices of linear layers the other way
        around. If that had been necessary here then we could have done it with
        `get = lambda embed_and_lin: jnp.transpose(embed_and_lin[0].weight)`.)_
    """

    pytree: PyTree
    where: Callable
    get: Callable

    def __init__(self, pytree: PyTree, where: Callable, get: Callable):
        """**Arguments:**

        - `pytree`: The PyTree to share some nodes across.
        - `where`: a function specifying either a single node, or a sequence of nodes,
            as with `eqx.tree_at(where, pytree, ...)`.
        - `get`: a function, which when evaluated on `pytree`, returns either a single
            value (if `where` returns a single node), or a sequence of values (if
            `where` returns a sequence, in which case it must be a sequence of the
            same length).

        The value(s) returned by `get(pytree)` will be tied to the node(s) specified
        by `where(pytree)`.

        !!! info

            To explain how this works: the implementation is essentially just

            ```python
            class Shared(eqx.Module):
                pytree: PyTree
                where: Callable
                get: Callable

                def __init__(self, pytree, where, get):
                    # `0` is just some dummy value
                    self.pytree = eqx.tree_at(where, pytree, replace_fn=lambda _: 0)
                    self.where = where
                    self.get = get

                def __call__(self):
                    return eqx.tree_at(self.where, self.pytree, self.get(self.pytree))
            ```

            so that at `__init__` time, the duplicate nodes specified by `where` are
            removed from the PyTree. We no longer have a separate copy updating during
            training.

            Then at `__call__` time, references to the values returned by
            `get(pytree)` are put in their place. We end up with a PyTree of the same
            structure as what we started with, which we can now use (evaluate as a
            layer etc.) as normal.

        !!! tip

            If you need to apply any transform (e.g. transposing a matrix), then this
            can be done as part of `get`. For example,
            `get = lambda pair: jnp.transpose(pair[1].weight)`.
        """
        source_struct = filter_eval_shape(get, pytree)
        dest_struct = filter_eval_shape(where, pytree)
        if tree_equal(source_struct, dest_struct) is not True:
            raise ValueError(
                "Every node being shared together must have the same pytree "
                "structure, shape+dtype of arrays, etc., as each other. Got:\n"
                f"{source_struct}\n"
                "and\n"
                f"{dest_struct}"
            )
        self.pytree = tree_at(where, pytree, replace_fn=lambda _: SharedNode())
        self.where = where
        self.get = get

    def __call__(self):
        """**Arguments:**

        None.

        **Returns:**

        A PyTree of the same structure as the original `pytree`, with `get(pytree)` in
        the place of the nodes at `where(pytree)`.
        """
        return tree_at(self.where, self.pytree, self.get(self.pytree))
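The tip at the end of `__init__`'s docstring mentions applying a transform inside `get`. Here is a short sketch of that pattern, tying one linear layer's weight to the transpose of another's. The layer names and sizes are illustrative only and are not taken from this commit:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

k1, k2 = jr.split(jr.PRNGKey(0))
down = eqx.nn.Linear(4, 2, use_bias=False, key=k1)  # weight shape (2, 4)
up = eqx.nn.Linear(2, 4, use_bias=False, key=k2)    # weight shape (4, 2)

where = lambda pair: pair[1].weight                  # node to overwrite
get = lambda pair: jnp.transpose(pair[0].weight)     # transformed value to put there
shared = eqx.nn.Shared((down, up), where, get)

down_, up_ = shared()  # expand before use
assert (up_.weight == down_.weight.T).all()
```

Because the transform lives inside `get`, the structure check in `__init__` compares the *transposed* shape against the target node, so the shapes line up even though the two layers store their weights the other way around.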
@@ -0,0 +1,125 @@
import jax
import jax.numpy as jnp
import jax.random as jr
import pytest
from jaxtyping import Array, Float, Int

import equinox as eqx


def test_shared_array(getkey):
    class MyModule(eqx.Module):
        shared: eqx.nn.Shared

        def __init__(self):
            embedding = eqx.nn.Embedding(
                num_embeddings=3, embedding_size=4, key=getkey()
            )
            head = eqx.nn.Linear(4, 3, key=getkey())
            where = lambda pair: pair[1].weight
            get = lambda pair: pair[0].weight
            self.shared = eqx.nn.Shared((embedding, head), where, get)

        def __call__(self, token: Int[Array, ""]):
            nonlocal called
            called = True
            embedding, head = self.shared()
            assert embedding.weight is head.weight
            return head(embedding(token))

    called = False
    module = MyModule()
    module(jnp.array(0))
    assert called


# We share a non-leaf node
def test_shared_node(getkey):
    class MyModule(eqx.Module):
        shared: eqx.nn.Shared

        def __init__(self):
            attention = eqx.nn.MultiheadAttention(
                num_heads=3, query_size=12, key=getkey()
            )
            my_proj = eqx.nn.Linear(12, 12, use_bias=False, key=getkey())
            where = lambda pair: pair[1].key_proj
            get = lambda pair: pair[0]
            self.shared = eqx.nn.Shared((my_proj, attention), where, get)

        def __call__(self, x: Float[Array, "seq 12"]):
            nonlocal called
            called = True
            my_proj, attention = self.shared()
            eq = eqx.tree_equal(my_proj, attention.key_proj)
            x = attention(x, x, x)
            out = jax.vmap(my_proj)(x)
            return out, eq

    called = False
    module = MyModule()
    x = jr.normal(getkey(), (5, 12))

    @eqx.filter_jit
    @eqx.filter_grad(has_aux=True)
    def f(module, x):
        out, eq = module(x)
        return jnp.sum(out), eq

    d_module, eq = f(module, x)
    assert called
    assert eq
    module = eqx.apply_updates(module, d_module)
    d_module, eq = f(module, x)
    assert eq
    module = eqx.apply_updates(module, d_module)


def test_mismatched_structure(getkey):
    x = jr.normal(getkey(), (3, 4))
    y = jr.normal(getkey(), (4, 3))
    with pytest.raises(ValueError, match="Every node being shared together"):
        eqx.nn.Shared((x, y), lambda pair: pair[0], lambda pair: pair[1])


def test_multi_shared(getkey):
    class MyModule(eqx.Module):
        shared: eqx.nn.Shared

        def __init__(self):
            my_proj = eqx.nn.Linear(12, 12, use_bias=False, key=getkey())
            attention = eqx.nn.MultiheadAttention(
                num_heads=3, query_size=12, key=getkey()
            )
            where = lambda pair: (pair[1].key_proj, pair[1].query_proj.weight)
            get = lambda pair: (pair[0], pair[0].weight + 1)
            self.shared = eqx.nn.Shared((my_proj, attention), where, get)

        def __call__(self, x: Float[Array, "seq 12"]):
            nonlocal called
            called = True
            my_proj, attention = self.shared()
            eq1 = eqx.tree_equal(my_proj, attention.key_proj)
            eq2 = (my_proj.weight + 1 == attention.query_proj.weight).all()
            x = attention(x, x, x)
            out = jax.vmap(my_proj)(x)
            eq = eq1 & eq2
            return out, eq

    called = False
    module = MyModule()
    x = jr.normal(getkey(), (5, 12))

    @eqx.filter_jit
    @eqx.filter_grad(has_aux=True)
    def f(module, x):
        out, eq = module(x)
        return jnp.sum(out), eq

    d_module, eq = f(module, x)
    assert called
    assert eq
    module = eqx.apply_updates(module, d_module)
    d_module, eq = f(module, x)
    assert eq
    module = eqx.apply_updates(module, d_module)
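Beyond the correctness checks above, one way to see why removing the duplicate node matters is to look at gradients: every use of the tied parameter contributes to the single stored copy. The following is a hedged sketch along the lines of the tests, with illustrative names and sizes that are not taken from this commit:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

k1, k2 = jr.split(jr.PRNGKey(0))
embedding = eqx.nn.Embedding(num_embeddings=3, embedding_size=4, key=k1)
head = eqx.nn.Linear(4, 3, key=k2)
shared = eqx.nn.Shared(
    (embedding, head),
    where=lambda pair: pair[1].weight,
    get=lambda pair: pair[0].weight,
)

@eqx.filter_grad
def loss(shared, token):
    embedding, head = shared()  # the same array is used twice here
    return jnp.sum(head(embedding(token)))

grads = loss(shared, jnp.array(0))
# The head's weight slot was replaced by a placeholder at construction time, so
# there is no second gradient array: the contributions from both the embedding
# lookup and the head matmul are summed into this one (3, 4) array.
print(grads.pytree[0].weight)
```

Applying an optimiser update to `grads` therefore updates the one stored array, and the next call to `shared()` propagates that update to every place the parameter appears.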