diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index 701e0ea09..e355c883c 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -1,8 +1,7 @@
+import math
 from typing import Optional, Union, overload
 
 import equinox as eqx
-
-# import jax
 import jax.numpy as jnp
 from jax.lib import xla_bridge
 from jaxtyping import PRNGKeyArray
@@ -63,10 +62,11 @@ def dot_product_attention(
     accelerator_type = xla_bridge.get_backend().platform.lower()
 
     if accelerator_type == "tpu" or accelerator_type == "cpu":
-        from levanter.models.flash_attention import BLOCK_SIZE, flash_attention
+        from levanter.models.flash_attention import flash_attention
 
         if flash_block_size is None:
-            flash_block_size = BLOCK_SIZE
+            seq_len = query.shape["position"]
+            flash_block_size = seq_len
 
         return flash_attention(
             QPos,
@@ -88,49 +88,32 @@ def dot_product_attention(
         from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, self_fused_attn
 
         # TransformerEngine looks for qkv shape:
-        # batch_shape, max_seqlen, nqkv, num_heads, head_dim
-        # where nqkv = 3
+        # batch_shape, max_seqlen, nqkv, num_heads, head_dim; where nqkv = 3
         q_ = haliax.rearrange(query, ("batch", "position", "heads", "head_size")).array
         k_ = haliax.rearrange(key, ("batch", "key_position", "heads", "head_size")).array
         v_ = haliax.rearrange(value, ("batch", "key_position", "heads", "head_size")).array
 
-        # print(f"\nShape of q_ is now {q_.shape}\n")
-
         qkv = jnp.stack((q_, k_, v_), axis=2)
-        # print(f"\nQKV shape is now: {qkv.shape}")
 
-        # scaling_factor = jax.lax.rsqrt(float(query.axis_size(Key)))
-        scaling_factor = 1.0
-        # print(f"\n\nScaling factor: {scaling_factor}\n\n")
+        scaling_factor = 1 / math.sqrt(float(query.axis_size(Key)))
         is_training = not inference
 
         # TODO: bias type is probably also configurable
         attn_bias_type = AttnBiasType.NO_BIAS
         fused_attn_bias = None
         if bias:
-            fused_attn_bias = bias.array
+            raise Exception(
+                "Using bias with flash attention on GPU is not currently implemented. Try setting"
+                " model.use_bias=False"
+            )
+            # fused_attn_bias = bias.array
 
         # TODO: We have a mask type we can use to configure this
         attn_mask_type = AttnMaskType.CAUSAL_MASK
-        # print(f"\nQuery Shape: {query.shape}\n")
 
         batch_size = query.shape["batch"]
        seq_len = query.shape["position"]
         mask = jnp.tril(jnp.ones((batch_size, seq_len, seq_len)))
-        # fused_attn_mask = None
-
-        # if mask:
-        #     fused_attn_mask = mask.array
-
-        # print("\n\n")
-        # print(f"qkv type: {type(qkv)}")
-        # print(f"bias type: {type(bias)}")
-        # print(f"mask type: {type(mask)}")
-        # print(f"seed type: {type(prng)}")
-        # print(f"attn_bias_Type type: {type(attn_bias_type)}")
-        # print(f"attn_mask_type type: {type(attn_mask_type)}")
-        # print(f"scaling factor type: {type(scaling_factor)}")
-        # print("\n\n")
 
         attn_output = self_fused_attn(
             qkv=qkv,  # jnp.ndarray,
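
For reference, the sketch below spells out the unfused computation that the new GPU path is meant to approximate: q/k/v stacked into a single (batch, seq_len, 3, heads, head_size) tensor, a scaling factor of 1/sqrt(head_size) (replacing the old hard-coded 1.0), and a lower-triangular causal mask. This is an illustration only, not the TransformerEngine code path; the shapes and random inputs are made up, and dropout and bias handling are omitted.

# Reference-only sketch (not the TransformerEngine fused kernel): plain scaled
# dot-product attention using the same qkv stacking, scaling factor, and causal
# mask as the diff above. Shapes and inputs are arbitrary illustration values.
import math

import jax
import jax.numpy as jnp

batch, seq_len, heads, head_size = 2, 8, 4, 16
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q_ = jax.random.normal(kq, (batch, seq_len, heads, head_size))
k_ = jax.random.normal(kk, (batch, seq_len, heads, head_size))
v_ = jax.random.normal(kv, (batch, seq_len, heads, head_size))

# Same layout TransformerEngine expects: (batch, max_seqlen, nqkv=3, heads, head_size)
qkv = jnp.stack((q_, k_, v_), axis=2)

# scaling_factor = 1 / sqrt(head_dim), matching the replacement for the old 1.0
scaling_factor = 1 / math.sqrt(head_size)

# Lower-triangular causal mask, as built with jnp.tril in the diff
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

q_s, k_s, v_s = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
scores = jnp.einsum("bqhd,bkhd->bhqk", q_s, k_s) * scaling_factor
scores = jnp.where(causal[None, None, :, :], scores, -1e9)
weights = jax.nn.softmax(scores, axis=-1)
attn_output = jnp.einsum("bhqk,bkhd->bqhd", weights, v_s)
print(attn_output.shape)  # (2, 8, 4, 16) == (batch, seq_len, heads, head_size)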