Added RoPE Offset + Test #860

Open · wants to merge 3 commits into main
Conversation

Artur-Galstyan (Contributor)

This PR follows #799 and adds support for an offset argument in RoPE.

I've rebased my changes on the latest main branch and added one more test.

For your convenience, here's some code to test it with MHA:

import functools
import equinox as eqx
import jax
import jax.numpy as jnp
from equinox.nn._attention import MultiheadAttention
from equinox.nn._embedding import RotaryPositionalEmbedding
from jaxtyping import Array, Float, Int


embedding_size = 32
max_seq_length = 8
seq_length = 4
num_heads = 2
query_size = 64


class TransformerBlock(eqx.Module):
    rope_embeddings: RotaryPositionalEmbedding
    mha_attention: MultiheadAttention

    def __init__(self, embedding_size, max_seq_length, num_heads, query_size):
        self.rope_embeddings = RotaryPositionalEmbedding(embedding_size, max_seq_length)
        self.mha_attention = MultiheadAttention(
            num_heads=num_heads, query_size=query_size, key=jax.random.key(0)
        )

    def __call__(self, query, key_, value, index):
        def process_heads(
            query_heads: Float[Array, "seq_length num_heads qk_size"],
            key_heads: Float[Array, "seq_length num_heads qk_size"],
            value_heads: Float[Array, "seq_length num_heads vo_size"],
            index: Int[Array, ""],
        ) -> tuple[
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads vo_size"],
        ]:
            # index is the autoregressive index of the current token
            rope_partial = functools.partial(self.rope_embeddings, offset=index)
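            # Apply RoPE to each head independently by vmapping over the head axis (axis 1).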
            query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
            key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)

            return query_heads, key_heads, value_heads

        x = self.mha_attention(
            query=query,
            key_=key_,
            value=value,
            process_heads=functools.partial(process_heads, index=index),
        )

        return x


transformer_block = TransformerBlock(
    embedding_size, max_seq_length, num_heads, query_size
)
transformer_block = eqx.filter_jit(transformer_block)


q = jnp.ones(shape=(seq_length, query_size))
k = jnp.ones(shape=(seq_length, query_size))
v = jnp.ones(shape=(seq_length, query_size))

# Call with increasing autoregressive indices. Each Python int is a static
# argument under eqx.filter_jit, so every new index compiles a fresh version;
# passing jnp.array(i) instead should avoid the recompilation.
out = transformer_block(q, k, v, 0)
out = transformer_block(q, k, v, 1)
out = transformer_block(q, k, v, 2)
out = transformer_block(q, k, v, 3)
out = transformer_block(q, k, v, 4)
out = transformer_block(q, k, v, 5)
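
To make the semantics of the new argument concrete, here's a smaller, self-contained sketch (not the PR's actual test). It assumes that offset simply shifts the position index RoPE uses, so rotating a suffix of a sequence with the matching offset reproduces the tail of the full-sequence result. The constructor call mirrors the example above; the exact signature is whatever this PR lands on.

import jax
import jax.numpy as jnp
from equinox.nn._embedding import RotaryPositionalEmbedding

rope = RotaryPositionalEmbedding(32, 8)  # (embedding_size, max_seq_length), as above
x = jax.random.normal(jax.random.key(1), shape=(8, 32))

# Rotate the full sequence, which starts at position 0.
full = rope(x, offset=jnp.array(0))

# Rotate only the last two tokens, telling RoPE they sit at positions 6 and 7.
tail = rope(x[6:], offset=jnp.array(6))

# If the offset shifts positions as assumed, the two results agree.
assert jnp.allclose(full[6:], tail, atol=1e-6)

This is the property the offset is meant to enable for autoregressive decoding: a newly generated token can be rotated by its absolute position without re-running RoPE over the whole sequence.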
