Commit b22c0f7

default to FA=on, adjust default block size for tpu
dlwh committed Feb 14, 2024
1 parent 58b1de6 commit b22c0f7
Showing 5 changed files with 27 additions and 16 deletions.
8 changes: 5 additions & 3 deletions src/levanter/models/flash_attention.py
@@ -19,8 +19,7 @@
 from levanter.models.attention import AttentionMask, materialize_mask
 
 
-# TODO: tune
-BLOCK_SIZE = 128
+BLOCK_SIZE = 1024
 
 
 @named_call
@@ -37,7 +36,7 @@ def flash_attention(
     dropout: float = 0.0,
     inference: bool,
     key: Optional[PRNGKeyArray] = None,
-    block_size: int = BLOCK_SIZE,
+    block_size: Optional[int] = None,
     dtype: Optional[jnp.dtype] = None,
     precision: PrecisionLike = None,
 ):
@@ -58,6 +57,9 @@ def flash_attention(
     q = q.astype(dtype)
     k = k.astype(dtype)
 
+    if block_size is None:
+        block_size = BLOCK_SIZE
+
     # premultiply by 1/sqrt(d_k) for normal dot product attention
     q = q / math.sqrt(float(q.axis_size(Key)))
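
In practice, callers can now omit block_size and get the module-level default of 1024, or pass it explicitly to override, as the updated tests below do. A minimal sketch of both call patterns, with illustrative axis sizes that are not taken from this commit:

import jax.random as jrandom
import haliax as hax

from levanter.models.flash_attention import flash_attention

# Illustrative axes; any sizes compatible with the chosen block size work.
Key = hax.Axis("Key", 8)
QPos = hax.Axis("QPos", 2048)
KPos = hax.Axis("KPos", 2048)

q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key))
k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key))

# block_size omitted -> falls back to the new module default (1024).
out_default = flash_attention(QPos, KPos, Key, q, k, v, inference=True)

# An explicit block_size still takes precedence, as the updated tests do.
out_small = flash_attention(QPos, KPos, Key, q, k, v, inference=True, block_size=64)
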
4 changes: 2 additions & 2 deletions src/levanter/models/gpt2.py
@@ -64,8 +64,8 @@ class Gpt2Config(HFCompatConfig):
 
     use_bias: bool = True
 
-    use_flash_attention: bool = False  # use flash attention. This is a pure jax impl, and is not faster than normal, but it scales to long sequence lengths
-    flash_attention_block_size: int = 1024
+    use_flash_attention: bool = True  # use flash attention. This is a pure jax impl, and is not faster than normal, but it scales to long sequence lengths
+    flash_attention_block_size: Optional[int] = None
 
     # Axes
     Pos = property(lambda self: Axis(name="position", size=self.seq_len))
2 changes: 1 addition & 1 deletion src/levanter/models/llama.py
@@ -66,7 +66,7 @@ class LlamaConfig(HFCompatConfig):
 
     # Attention-related config
     upcast_attn: bool = False
-    use_flash_attention: bool = False
+    use_flash_attention: bool = True
     flash_attention_block_size: Optional[int] = None
 
     gradient_checkpointing: bool = True
2 changes: 1 addition & 1 deletion src/levanter/models/mistral.py
@@ -61,7 +61,7 @@ class MistralConfig(LlamaConfig):
 
     # Attention-related config
     upcast_attn: bool = False
-    use_flash_attention: bool = False
+    use_flash_attention: bool = True
     flash_attention_block_size: Optional[int] = None
 
     gradient_checkpointing: bool = True
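
Taken together, the three config diffs above flip the same two defaults in Gpt2Config, LlamaConfig, and MistralConfig. A minimal sketch of what that means for code that builds these configs, assuming the remaining fields keep their upstream defaults:

from levanter.models.gpt2 import Gpt2Config
from levanter.models.llama import LlamaConfig

# Flash attention is now on by default, and the block size is left unset,
# so flash_attention falls back to its module-level default at call time.
cfg = Gpt2Config()
assert cfg.use_flash_attention is True
assert cfg.flash_attention_block_size is None

# An explicit block size can still be configured; MistralConfig inherits
# the same defaults from LlamaConfig.
llama_cfg = LlamaConfig(flash_attention_block_size=512)
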
27 changes: 18 additions & 9 deletions tests/test_flash_attention.py
@@ -9,7 +9,10 @@
 import haliax.nn as hnn
 
 from levanter.models.attention import AttentionMask
-from levanter.models.flash_attention import BLOCK_SIZE, flash_attention
+from levanter.models.flash_attention import flash_attention
+
+
+BLOCK_SIZE = 64
 
 
 def test_flash_attention_acausal():
@@ -21,7 +24,7 @@ def test_flash_attention_acausal():
     k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
     v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key))
 
-    flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True)
+    flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True, block_size=BLOCK_SIZE)
     hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v)
 
     assert hax_out.axes == flash_out.axes
@@ -39,7 +42,7 @@ def test_flash_attention_causal_mask():
     k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
     v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key))
 
-    flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True, mask=mask)
+    flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True, mask=mask, block_size=BLOCK_SIZE)
     hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos))
 
     assert hax_out.axes == flash_out.axes
@@ -48,8 +51,8 @@ def test_flash_attention_causal_mask():
 
 def test_grad_attention():
     Key = hax.Axis("Key", 8)
-    QPos = hax.Axis("QPos", BLOCK_SIZE * 2)
-    KPos = hax.Axis("KPos", BLOCK_SIZE * 2)
+    QPos = hax.Axis("QPos", BLOCK_SIZE * 4)
+    KPos = hax.Axis("KPos", BLOCK_SIZE * 4)
 
     mask = AttentionMask.causal()
 
@@ -68,7 +71,9 @@ def d_attn(qkv, fn):
         return (x_out * x_out).sum().scalar()
 
     hax_val, (hax_dq, hax_dk, hax_dv) = d_attn((q, k, v), hnn.attention.dot_product_attention)
-    fa_val, (fa_dq, fa_dk, fa_dv) = d_attn((q, k, v), functools.partial(flash_attention, QPos, inference=True))
+    fa_val, (fa_dq, fa_dk, fa_dv) = d_attn(
+        (q, k, v), functools.partial(flash_attention, QPos, inference=True, block_size=BLOCK_SIZE)
+    )
 
     assert jnp.allclose(hax_val, fa_val, atol=1e-4, rtol=1e-4)
     assert hax_dq.axes == fa_dq.axes
@@ -102,7 +107,9 @@ def d_attn(qkv, fn):
         return (x_out * x_out).sum().scalar()
 
     hax_val, (hax_dq, hax_dk, hax_dv) = d_attn((q, k, v), hnn.attention.dot_product_attention)
-    fa_val, (fa_dq, fa_dk, fa_dv) = d_attn((q, k, v), functools.partial(flash_attention, QPos, inference=True))
+    fa_val, (fa_dq, fa_dk, fa_dv) = d_attn(
+        (q, k, v), functools.partial(flash_attention, QPos, inference=True, block_size=BLOCK_SIZE)
+    )
 
     assert jnp.allclose(hax_val, fa_val, atol=1e-4, rtol=1e-4)
     assert hax_dq.axes == fa_dq.axes
@@ -127,8 +134,10 @@ def test_fa_dropout_does_something():
 
     p_drop = 0.5
 
-    fa_with_dropout = functools.partial(flash_attention, inference=False, dropout=p_drop, key=jrandom.PRNGKey(3))
-    fa_without_dropout = functools.partial(flash_attention, inference=True)
+    fa_with_dropout = functools.partial(
+        flash_attention, inference=False, dropout=p_drop, key=jrandom.PRNGKey(3), block_size=BLOCK_SIZE
+    )
+    fa_without_dropout = functools.partial(flash_attention, inference=True, block_size=BLOCK_SIZE)
 
     without_o = fa_without_dropout(QPos, KPos, Key, q, k, v, mask=mask)
     with_o = fa_with_dropout(QPos, KPos, Key, q, k, v, mask=mask)
