Causal flash (#469)
* take advantage of block causal structure in flash attention?

* default to FA=on, adjust default block size for tpu

* fix tests

* missed a spot

* ok actually fix tests maybe

* dumb
dlwh authored Feb 14, 2024
1 parent 0160115 commit 936cf82
Showing 15 changed files with 98 additions and 39 deletions.
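
In short: when the attention mask is causal, the blockwise flash kernel now skips key/value blocks that lie entirely above the diagonal; flash attention becomes the default for the GPT-2, Llama, and Mistral configs with a larger default block size (1024, adjusted for TPU per the commit message); and sequences shorter than one block fall back to plain attention. A minimal standalone sketch of the block-skipping arithmetic (illustrative only, not the levanter code):

```python
# Illustrative sketch: which (i, j) tiles the forward loop visits.
# With a causal mask, query block i only needs key blocks j <= i,
# mirroring `j_end = jnp.minimum(i + 1, Tc) if is_causal else Tc` in the diff below.
def visited_blocks(Tr: int, Tc: int, is_causal: bool) -> list[tuple[int, int]]:
    blocks = []
    for i in range(Tr):
        j_end = min(i + 1, Tc) if is_causal else Tc
        blocks.extend((i, j) for j in range(j_end))
    return blocks

print(len(visited_blocks(8, 8, is_causal=False)))  # 64 tiles
print(len(visited_blocks(8, 8, is_causal=True)))   # 36 tiles, roughly half the work
```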
40 changes: 29 additions & 11 deletions src/levanter/models/attention.py
@@ -57,10 +57,7 @@ def dot_product_attention(
raise ValueError("QPos and KPos must be different")

if use_flash:
from levanter.models.flash_attention import BLOCK_SIZE, flash_attention

if flash_block_size is None:
flash_block_size = BLOCK_SIZE
from levanter.models.flash_attention import flash_attention

return flash_attention(
QPos,
@@ -79,14 +76,35 @@ def dot_product_attention(
precision=precision,
)
else:
QPos = query.resolve_axis(QPos)
KPos = key.resolve_axis(KPos)
m = materialize_mask(mask, QPos, KPos)
weights = haliax.nn.attention.dot_product_attention_weights(
Key, KPos, query, key, mask=m, bias=bias, attention_dtype=attention_dtype, precision=precision
return simple_attention_with_dropout(
QPos, KPos, Key, query, key, value, mask, bias, inference, dropout, attention_dtype, precision, prng=prng
)
weights = haliax.nn.dropout(weights, dropout, key=prng, inference=inference)
return haliax.dot(KPos, weights, value)


def simple_attention_with_dropout(
QPos: Axis,
KPos: Axis,
Key: Axis,
query: NamedArray,
key: NamedArray,
value: NamedArray,
mask: Optional[Union[NamedArray, "AttentionMask"]] = None,
bias: Optional[NamedArray] = None,
inference: bool = False,
dropout: float = 0.0,
attention_dtype: Optional[jnp.dtype] = None,
precision: PrecisionLike = None,
*,
prng: Optional[PRNGKeyArray] = None,
):
QPos = query.resolve_axis(QPos)
KPos = key.resolve_axis(KPos)
m = materialize_mask(mask, QPos, KPos)
weights = haliax.nn.attention.dot_product_attention_weights(
Key, KPos, query, key, mask=m, bias=bias, attention_dtype=attention_dtype, precision=precision
)
weights = haliax.nn.dropout(weights, dropout, key=prng, inference=inference)
return haliax.dot(KPos, weights, value)


class AttentionMask(eqx.Module):
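
The non-flash branch of `dot_product_attention` is now factored into `simple_attention_with_dropout`, which the flash path also reuses as a short-sequence fallback (see the next file). A hypothetical direct call, with made-up axis names and sizes, might look like this:

```python
import haliax as hax
import jax.random as jrandom

from levanter.models.attention import AttentionMask, simple_attention_with_dropout

Key = hax.Axis("Key", 8)
QPos = hax.Axis("QPos", 16)
KPos = hax.Axis("KPos", 16)

q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key))
k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key))

# materialize_mask inside the helper accepts either a boolean NamedArray
# or an AttentionMask such as AttentionMask.causal().
out = simple_attention_with_dropout(
    QPos, KPos, Key, q, k, v, mask=AttentionMask.causal(), inference=True
)
```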
37 changes: 27 additions & 10 deletions src/levanter/models/flash_attention.py
@@ -1,6 +1,7 @@
# cf https://github.com/lucidrains/flash-attention-jax
# cf https://tridao.me/publications/flash2/flash2.pdf
# cf https://arxiv.org/pdf/2205.14135.pdf
import math
from typing import Optional, Tuple

import equinox
@@ -18,8 +19,7 @@
from levanter.models.attention import AttentionMask, materialize_mask


# TODO: tune
BLOCK_SIZE = 128
BLOCK_SIZE = 1024


@named_call
@@ -36,7 +36,7 @@ def flash_attention(
dropout: float = 0.0,
inference: bool,
key: Optional[PRNGKeyArray] = None,
block_size: int = BLOCK_SIZE,
block_size: Optional[int] = None,
dtype: Optional[jnp.dtype] = None,
precision: PrecisionLike = None,
):
@@ -57,12 +57,22 @@
q = q.astype(dtype)
k = k.astype(dtype)

# premultiply by 1/sqrt(d_k) for normal dot product attention
q = q * jax.lax.rsqrt(float(q.axis_size(Key)))
if block_size is None:
block_size = BLOCK_SIZE

QPos = q.resolve_axis(QPos)
KPos = k.resolve_axis(KPos)

if QPos.size < block_size or KPos.size < block_size:
from levanter.models.attention import simple_attention_with_dropout

return simple_attention_with_dropout(
QPos, KPos, Key, q, k, v, mask=mask, bias=bias, dropout=dropout, inference=inference, prng=key
)

# premultiply by 1/sqrt(d_k) for normal dot product attention
q = q / math.sqrt(float(q.axis_size(Key)))

return _flash_attention(
(q, k, v),
QPos,
@@ -145,6 +155,8 @@ def _flash_attention_forward(
ell = hax.zeros((*q_batch_axes, QPos))
ell = hax.auto_sharded(ell)

is_causal = isinstance(mask, AttentionMask) and mask.is_causal

@named_call
def do_o_block(state):
i, o, ell = state
@@ -187,8 +199,6 @@ def do_qk_block(state):
mask_ij = _materialize_mask_slice(mask, i, j, QPos, KPos, block_size)
attn_ij = hax.where(mask_ij, attn_ij, -1e10)

# TODO: block causal

if dropout > 0 and not inference:
attn_ij = hax.nn.dropout(attn_ij, dropout, inference=False, key=jax.random.fold_in(key, i * Tc + j))

@@ -205,8 +215,10 @@ def do_qk_block(state):

return (i, j + 1, o_i, q_i, sumexp_i, max_i)

j_end = jnp.minimum(i + 1, Tc) if is_causal else Tc

_, _, o_i, _, sumexp_i, max_i = jax.lax.while_loop(
lambda state: state[1] < Tc, do_qk_block, (i, 0, o_i, q_i, sumexp_i, max_i)
lambda state: state[1] < j_end, do_qk_block, (i, 0, o_i, q_i, sumexp_i, max_i)
)

# Step 12: compute O_i = diag(\ell_i^{Tc})^{-1} O_i^{Tc}
@@ -234,7 +246,7 @@ def _flash_attention_backward(
QPos: hax.Axis,
KPos: hax.Axis,
Key: hax.AxisSelector,
mask: Optional[hax.NamedArray] = None,
mask: Optional[AttentionMask | hax.NamedArray] = None,
bias: Optional[hax.NamedArray] = None,
dropout: float = 0.0,
*,
@@ -263,6 +275,8 @@ def _flash_attention_backward(
dK = (k * 0.0).astype(k.dtype)
dV = (v * 0.0).astype(v.dtype)

is_causal = isinstance(mask, AttentionMask) and mask.is_causal

@named_call
def do_kv_block(state):
j, dQ, dK, dV = state
@@ -322,7 +336,10 @@ def do_inner_block(state):
return i + 1, j, dQ, dK_j, dV_j

# dQ, dK_j, dV_j = hax.fold(do_inner_block, Tr)((dQ, dK_j, dV_j), jnp.arange(Tr.size))
i, j, dQ, dK_j, dV_j = jax.lax.while_loop(lambda state: state[0] < Tr, do_inner_block, (0, j, dQ, dK_j, dV_j))
i_start = j if is_causal else 0
i, j, dQ, dK_j, dV_j = jax.lax.while_loop(
lambda state: state[0] < Tr, do_inner_block, (i_start, j, dQ, dK_j, dV_j)
)

dK = dK.updated_slice({KPos: j * block_size}, dK_j)
dV = dV.updated_slice({KPos: j * block_size}, dV_j)
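
Two behavioral notes on this file, sketched below under the commit's own defaults: `BLOCK_SIZE` is now 1024 and is applied only when no block size is passed, and inputs shorter than one block skip the blockwise kernel entirely in favor of `simple_attention_with_dropout`. Axis names and sizes in the sketch are arbitrary:

```python
import haliax as hax
import jax.random as jrandom

from levanter.models.attention import AttentionMask
from levanter.models.flash_attention import flash_attention

Key = hax.Axis("Key", 8)
QPos = hax.Axis("QPos", 256)
KPos = hax.Axis("KPos", 256)

q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key))
k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key))
mask = AttentionMask.causal()

# 256 < 1024 (the new default block size), so this call takes the
# simple-attention fallback rather than the blockwise kernel.
out_fallback = flash_attention(QPos, KPos, Key, q, k, v, mask=mask, inference=True)

# Forcing a small block size exercises the blockwise, block-causal path.
out_blockwise = flash_attention(
    QPos, KPos, Key, q, k, v, mask=mask, inference=True, block_size=64
)
```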
4 changes: 2 additions & 2 deletions src/levanter/models/gpt2.py
@@ -64,8 +64,8 @@ class Gpt2Config(HFCompatConfig):

use_bias: bool = True

use_flash_attention: bool = False # use flash attention. This is a pure jax impl, and is not faster than normal, but it scales to long sequence lengths
flash_attention_block_size: int = 1024
use_flash_attention: bool = True # use flash attention. This is a pure jax impl, and is not faster than normal, but it scales to long sequence lengths
flash_attention_block_size: Optional[int] = None

# Axes
Pos = property(lambda self: Axis(name="position", size=self.seq_len))
2 changes: 1 addition & 1 deletion src/levanter/models/llama.py
@@ -66,7 +66,7 @@ class LlamaConfig(HFCompatConfig):

# Attention-related config
upcast_attn: bool = False
use_flash_attention: bool = False
use_flash_attention: bool = True
flash_attention_block_size: Optional[int] = None

gradient_checkpointing: bool = True
2 changes: 1 addition & 1 deletion src/levanter/models/mistral.py
@@ -61,7 +61,7 @@ class MistralConfig(LlamaConfig):

# Attention-related config
upcast_attn: bool = False
use_flash_attention: bool = False
use_flash_attention: bool = True
flash_attention_block_size: Optional[int] = None

gradient_checkpointing: bool = True
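
The model configs above flip `use_flash_attention` to `True` and leave `flash_attention_block_size` unset so the kernel's default applies; the test changes that follow simply opt back out where the reference attention path is wanted. A hedged example of overriding the new defaults in code (field values are arbitrary):

```python
from levanter.models.gpt2 import Gpt2Config

# Opt back out of flash attention (as several tests below now do).
eager_config = Gpt2Config(
    seq_len=512,
    hidden_dim=64,
    num_layers=2,
    num_heads=2,
    use_flash_attention=False,
)

# Or keep it on but pin a small block size, e.g. to force the blockwise path.
flash_config = Gpt2Config(
    seq_len=512,
    hidden_dim=64,
    num_layers=2,
    num_heads=2,
    flash_attention_block_size=64,
)
```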
2 changes: 2 additions & 0 deletions tests/gpt2_test.py
@@ -21,6 +21,7 @@ def test_gradient_checkpointing(num_blocks):
num_layers=num_blocks,
num_heads=8,
gradient_checkpointing=False,
use_flash_attention=False,
)
config_checkpoint = dataclasses.replace(config, gradient_checkpointing=True)
key = PRNGKey(0)
@@ -54,5 +55,6 @@ def test_pass_different_length_seq_to_gpt2():
num_layers=4,
num_heads=2,
gradient_checkpointing=False,
use_flash_attention=False,
)
check_model_works_with_seqlen(Gpt2LMHeadModel, config, 16)
1 change: 1 addition & 0 deletions tests/test_backpack.py
@@ -131,5 +131,6 @@ def test_pass_different_length_seq():
num_layers=4,
num_heads=2,
gradient_checkpointing=False,
use_flash_attention=False,
)
check_model_works_with_seqlen(BackpackLMHeadModel, config, 16)
1 change: 1 addition & 0 deletions tests/test_eval_lm.py
@@ -35,6 +35,7 @@ def test_eval_lm():
num_heads=2,
seq_len=32,
hidden_dim=32,
use_flash_attention=False,
)

with tempfile.TemporaryDirectory() as f:
2 changes: 2 additions & 0 deletions tests/test_export_to_hf.py
@@ -23,6 +23,7 @@ def test_export_lm_to_hf():
num_layers=2,
num_heads=2,
seq_len=32,
use_flash_attention=False,
hidden_dim=32,
)

@@ -44,6 +45,7 @@ def test_export_lm_to_hf():
num_layers=2,
num_heads=2,
seq_len=32,
use_flash_attention=False,
hidden_dim=32,
),
)
39 changes: 26 additions & 13 deletions tests/test_flash_attention.py
@@ -9,7 +9,10 @@
import haliax.nn as hnn

from levanter.models.attention import AttentionMask
from levanter.models.flash_attention import BLOCK_SIZE, flash_attention
from levanter.models.flash_attention import flash_attention


BLOCK_SIZE = 64


def test_flash_attention_acausal():
@@ -21,7 +24,7 @@ def test_flash_attention_acausal():
k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key))

flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True)
flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True, block_size=BLOCK_SIZE)
hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v)

assert hax_out.axes == flash_out.axes
@@ -30,16 +33,16 @@ def test_flash_attention_acausal():

def test_flash_attention_causal_mask():
Key = hax.Axis("Key", 8)
QPos = hax.Axis("QPos", BLOCK_SIZE * 2)
KPos = hax.Axis("KPos", BLOCK_SIZE * 2)
QPos = hax.Axis("QPos", BLOCK_SIZE * 4)
KPos = hax.Axis("KPos", BLOCK_SIZE * 4)

mask = AttentionMask.causal()

q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key))
k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key))

flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True, mask=mask)
flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True, mask=mask, block_size=BLOCK_SIZE)
hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos))

assert hax_out.axes == flash_out.axes
@@ -48,10 +51,10 @@ def test_flash_attention_causal_mask():

def test_grad_attention():
Key = hax.Axis("Key", 8)
QPos = hax.Axis("QPos", BLOCK_SIZE * 2)
KPos = hax.Axis("KPos", BLOCK_SIZE * 2)
QPos = hax.Axis("QPos", BLOCK_SIZE * 4)
KPos = hax.Axis("KPos", BLOCK_SIZE * 4)

mask = hax.nn.attention.causal_mask(QPos, KPos)
mask = AttentionMask.causal()

q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key))
k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
@@ -60,11 +63,17 @@ def test_grad_attention():
@equinox.filter_value_and_grad
def d_attn(qkv, fn):
q, k, v = qkv
x_out = fn(KPos, Key, q, k, v, mask=mask)
if fn is hnn.attention.dot_product_attention:
my_mask = mask.materialize(QPos, KPos)
else:
my_mask = mask
x_out = fn(KPos, Key, q, k, v, mask=my_mask)
return (x_out * x_out).sum().scalar()

hax_val, (hax_dq, hax_dk, hax_dv) = d_attn((q, k, v), hnn.attention.dot_product_attention)
fa_val, (fa_dq, fa_dk, fa_dv) = d_attn((q, k, v), functools.partial(flash_attention, QPos, inference=True))
fa_val, (fa_dq, fa_dk, fa_dv) = d_attn(
(q, k, v), functools.partial(flash_attention, QPos, inference=True, block_size=BLOCK_SIZE)
)

assert jnp.allclose(hax_val, fa_val, atol=1e-4, rtol=1e-4)
assert hax_dq.axes == fa_dq.axes
@@ -98,7 +107,9 @@ def d_attn(qkv, fn):
return (x_out * x_out).sum().scalar()

hax_val, (hax_dq, hax_dk, hax_dv) = d_attn((q, k, v), hnn.attention.dot_product_attention)
fa_val, (fa_dq, fa_dk, fa_dv) = d_attn((q, k, v), functools.partial(flash_attention, QPos, inference=True))
fa_val, (fa_dq, fa_dk, fa_dv) = d_attn(
(q, k, v), functools.partial(flash_attention, QPos, inference=True, block_size=BLOCK_SIZE)
)

assert jnp.allclose(hax_val, fa_val, atol=1e-4, rtol=1e-4)
assert hax_dq.axes == fa_dq.axes
@@ -123,8 +134,10 @@ def test_fa_dropout_does_something():

p_drop = 0.5

fa_with_dropout = functools.partial(flash_attention, inference=False, dropout=p_drop, key=jrandom.PRNGKey(3))
fa_without_dropout = functools.partial(flash_attention, inference=True)
fa_with_dropout = functools.partial(
flash_attention, inference=False, dropout=p_drop, key=jrandom.PRNGKey(3), block_size=BLOCK_SIZE
)
fa_without_dropout = functools.partial(flash_attention, inference=True, block_size=BLOCK_SIZE)

without_o = fa_without_dropout(QPos, KPos, Key, q, k, v, mask=mask)
with_o = fa_with_dropout(QPos, KPos, Key, q, k, v, mask=mask)
1 change: 1 addition & 0 deletions tests/test_llama.py
@@ -347,5 +347,6 @@ def test_pass_different_length_seq(num_kv_heads):
intermediate_dim=32,
num_heads=2,
num_kv_heads=num_kv_heads,
use_flash_attention=False,
)
check_model_works_with_seqlen(LlamaLMHeadModel, config, 16)
3 changes: 2 additions & 1 deletion tests/test_lora.py
@@ -21,6 +21,7 @@
save_merged_hf_model,
save_peft_pretrained,
)
from levanter.models.attention import AttentionMask
from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel
from levanter.trainer import StepInfo, TrainerState
from levanter.utils.tree_utils import inference_mode
@@ -214,7 +215,7 @@ def test_lora_merged_load_in_hf():
input = hax.random.randint(jax.random.PRNGKey(0), config.Pos, 0, Vocab.size)
torch_input = torch.tensor(np.array(input.array), dtype=torch.long).reshape((1, -1))

causal_mask = hax.nn.attention.causal_mask(model.Pos, config.KeyPos)
causal_mask = AttentionMask.causal()

with (tempfile.TemporaryDirectory() as tmpdir):
converter.save_pretrained(model, f"{tmpdir}/model")
1 change: 1 addition & 0 deletions tests/test_mistral.py
@@ -161,5 +161,6 @@ def test_pass_different_length_seq(num_kv_heads):
intermediate_dim=32,
num_heads=2,
num_kv_heads=num_kv_heads,
use_flash_attention=False,
)
check_model_works_with_seqlen(MistralLMHeadModel, config, 16)
1 change: 1 addition & 0 deletions tests/test_train_lm.py
@@ -34,6 +34,7 @@ def test_train_lm():
num_heads=2,
seq_len=32,
hidden_dim=32,
flash_attention_block_size=32,
),
trainer=train_lm.TrainerConfig(
num_train_steps=2,
1 change: 1 addition & 0 deletions tests/test_viz_lm.py
@@ -38,6 +38,7 @@ def test_viz_lm():
num_heads=2,
hidden_dim=32,
seq_len=32,
use_flash_attention=False,
)

with tempfile.TemporaryDirectory() as f:
