Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added RoPE Offset + Test #860

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ examples/MNIST
examples/multipart_serialised.eqx
.python-version
.DS_Store
.ruff_cache
.pytest_cache
.venv
37 changes: 21 additions & 16 deletions equinox/nn/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,22 @@ def __call__(...):
def process_heads(
    query_heads: Float[Array, "seq_length num_heads qk_size"],
    key_heads: Float[Array, "seq_length num_heads qk_size"],
    value_heads: Float[Array, "seq_length num_heads vo_size"],
    index: Int[Array, ""]
) -> tuple[
    Float[Array, "seq_length num_heads qk_size"],
    Float[Array, "seq_length num_heads qk_size"],
    Float[Array, "seq_length num_heads vo_size"]
]:
    """Apply rotary positional embeddings to the query and key heads.

    - `index`: the autoregressive index of the current token; forwarded to
      `rope_embeddings` as its `offset` so that cached (incremental) decoding
      rotates by the token's absolute position rather than its position
      within the current chunk.

    Value heads are returned unchanged.
    """
    # NOTE(review): `rope_embeddings` is assumed to be bound in the enclosing
    # scope — confirm against the surrounding definition.
    rope_partial = functools.partial(rope_embeddings, offset=index)
    # vmap over the heads axis (axis 1); every head shares the same offset.
    # The call must stay on the same expression as the vmap: putting
    # `(query_heads)` on its own line is a bare no-op expression, not a call,
    # which silently leaves the heads un-rotated.
    query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
    key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)

    return query_heads, key_heads, value_heads

Expand Down Expand Up @@ -178,14 +182,14 @@ def rotate_half(x: Float[Array, "seq_length embedding_size"]):

@staticmethod
def precompute_freqs_cis(
embedding_size: int, end: int, theta: float, dtype: Any
embedding_size: int, end: Int[ArrayLike, ""], theta: float, dtype: Any
) -> tuple[Float[Array, "end half_emb_size"], Float[Array, "end half_emb_size"]]:
freqs = 1.0 / (
theta
** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
)

t = jnp.arange(float(end))
t = jnp.arange(float(end)) # type: ignore
freqs_outer = jnp.outer(t, freqs)

# we assign the type at the very end to minimize the loss of precision
Expand All @@ -195,12 +199,14 @@ def precompute_freqs_cis(
def __call__(
self,
x: Float[Array, "seq_length embedding_size"],
offset: Int[ArrayLike, ""] = 0,
*,
key: Optional[PRNGKeyArray] = None,
) -> Float[Array, "seq_length embedding_size"]:
"""**Arguments:**

- `x`: A JAX array of shape `(seq_length, embedding_size)`.
- `offset`: The offset to apply to the positional encoding.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)

Expand All @@ -216,24 +222,24 @@ def __call__(
f"x.shape[-1] must match self.embedding_size, "
f"but {x.shape[-1]} != {self.embedding_size}"
)

with jax.ensure_compile_time_eval():
min_required_seq_len = offset + seq_len # pyright: ignore
cache_key = (embedding_size, self.dtype)
if cache_key not in internal_rope_embedding_cache:
internal_rope_embedding_cache[cache_key] = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta, self.dtype
embedding_size, min_required_seq_len, self.theta, self.dtype
)

freqs_cos, freqs_sin = internal_rope_embedding_cache[cache_key]
freqs_seq_len, _ = freqs_cos.shape
if seq_len > freqs_seq_len:
if min_required_seq_len > freqs_seq_len: # pyright: ignore
internal_rope_embedding_cache[cache_key] = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta, self.dtype
embedding_size, min_required_seq_len, self.theta, self.dtype
)
freqs_cos, freqs_sin = internal_rope_embedding_cache[cache_key]

freqs_cos = freqs_cos[:seq_len]
freqs_sin = freqs_sin[:seq_len]
freqs_cos = jax.lax.dynamic_slice_in_dim(freqs_cos, offset, seq_len)
freqs_sin = jax.lax.dynamic_slice_in_dim(freqs_sin, offset, seq_len)

freqs_cos = jnp.tile(freqs_cos, (1, 2))
freqs_sin = jnp.tile(freqs_sin, (1, 2))
Expand All @@ -255,7 +261,6 @@ def __call__(


RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:**

- `embedding_size`: Size of each embedding vector. Must be non-negative and even.
- `theta`: The base frequency for the sinusoidal functions used in positional encoding.
Specifies how quickly the inner-product will decay with relative distance between
Expand Down
24 changes: 21 additions & 3 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def test_mlp_learnt_activation():
key=jrandom.PRNGKey(5678),
)
x = jnp.array([0.5, 0.7])
assert mlp.activation.negative_slope.shape == (2, 8)
assert mlp.final_activation.negative_slope.shape == (5,)
assert mlp.activation.negative_slope.shape == (2, 8) # pyright: ignore
assert mlp.final_activation.negative_slope.shape == (5,) # pyright: ignore

@eqx.filter_jit
@eqx.filter_grad
Expand Down Expand Up @@ -1353,13 +1353,14 @@ def test_prelu(getkey):

def test_rope_embeddings_shapes(getkey):
embedding_size = 32
rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)

n_heads = 4
seq_length = 8
query_size = 32
key_size = 32

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)

query_heads = jax.random.normal(
key=getkey(), shape=(seq_length, n_heads, query_size)
)
Expand Down Expand Up @@ -1477,3 +1478,20 @@ def test_rope_embeddings_values():
jnp.allclose(res.astype(jnp.float32), expected_values, rtol=1e-2)
and res.dtype == jnp.float16
)


def test_rope_with_offset():
    """RoPE output depends only on the absolute position (offset + row)."""
    embedding_size = 2

    rotary_emb = eqx.filter_jit(
        eqx.nn.RotaryPositionalEmbedding(embedding_size=embedding_size)
    )

    def ones(seq_len):
        return jnp.ones(shape=(seq_len, embedding_size))

    # Absolute position 2 appears as row 1 of the offset-1 input and as
    # row 0 of the offset-2 input; the embeddings there must agree.
    out1 = rotary_emb(ones(2), offset=1)
    out2 = rotary_emb(ones(3), offset=2)
    assert jnp.allclose(out1[1], out2[0])

    # Same offset, different sequence lengths: the shared prefix must match.
    out3 = rotary_emb(ones(2), offset=1)
    out4 = rotary_emb(ones(3), offset=1)
    assert jnp.allclose(out3, out4[:2])
Loading