Fixed the scaling value; throw an error when bias is used
Virginia Adams committed Feb 14, 2024
1 parent c5df598 commit 4ed3217
Showing 1 changed file with 11 additions and 28 deletions.
39 changes: 11 additions & 28 deletions src/levanter/models/attention.py
@@ -1,8 +1,7 @@
+import math
 from typing import Optional, Union, overload

 import equinox as eqx
-
-# import jax
 import jax.numpy as jnp
 from jax.lib import xla_bridge
 from jaxtyping import PRNGKeyArray
@@ -63,10 +62,11 @@ def dot_product_attention(
     accelerator_type = xla_bridge.get_backend().platform.lower()

     if accelerator_type == "tpu" or accelerator_type == "cpu":
-        from levanter.models.flash_attention import BLOCK_SIZE, flash_attention
+        from levanter.models.flash_attention import flash_attention

         if flash_block_size is None:
-            flash_block_size = BLOCK_SIZE
+            seq_len = query.shape["position"]
+            flash_block_size = seq_len

         return flash_attention(
             QPos,
@@ -88,49 +88,32 @@ def dot_product_attention(
         from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, self_fused_attn

         # TransformerEngine looks for qkv shape:
-        # batch_shape, max_seqlen, nqkv, num_heads, head_dim
-        # where nqkv = 3
+        # batch_shape, max_seqlen, nqkv, num_heads, head_dim; where nqkv = 3
         q_ = haliax.rearrange(query, ("batch", "position", "heads", "head_size")).array
         k_ = haliax.rearrange(key, ("batch", "key_position", "heads", "head_size")).array
         v_ = haliax.rearrange(value, ("batch", "key_position", "heads", "head_size")).array

-        # print(f"\nShape of q_ is now {q_.shape}\n")

         qkv = jnp.stack((q_, k_, v_), axis=2)

-        # print(f"\nQKV shape is now: {qkv.shape}")
-
-        # scaling_factor = jax.lax.rsqrt(float(query.axis_size(Key)))
-        scaling_factor = 1.0
-        # print(f"\n\nScaling factor: {scaling_factor}\n\n")
+        scaling_factor = 1 / math.sqrt(float(query.axis_size(Key)))
         is_training = not inference

         # TODO: bias type is probably also configurable
         attn_bias_type = AttnBiasType.NO_BIAS
         fused_attn_bias = None
         if bias:
-            fused_attn_bias = bias.array
+            raise Exception(
+                "Using bias with flash attention on GPU is not currently implemented. Try setting"
+                " model.use_bias=False"
+            )
+            # fused_attn_bias = bias.array

         # TODO: We have a mask type we can use to configure this
         attn_mask_type = AttnMaskType.CAUSAL_MASK
-        # print(f"\nQuery Shape: {query.shape}\n")
         batch_size = query.shape["batch"]
         seq_len = query.shape["position"]
         mask = jnp.tril(jnp.ones((batch_size, seq_len, seq_len)))
-        # fused_attn_mask = None
-
-        # if mask:
-        #     fused_attn_mask = mask.array
-
-        # print("\n\n")
-        # print(f"qkv type: {type(qkv)}")
-        # print(f"bias type: {type(bias)}")
-        # print(f"mask type: {type(mask)}")
-        # print(f"seed type: {type(prng)}")
-        # print(f"attn_bias_Type type: {type(attn_bias_type)}")
-        # print(f"attn_mask_type type: {type(attn_mask_type)}")
-        # print(f"scaling factor type: {type(scaling_factor)}")
-        # print("\n\n")

         attn_output = self_fused_attn(
             qkv=qkv, # jnp.ndarray,
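The headline change replaces the placeholder scaling factor of 1.0 with the standard scaled dot-product value, computed in the diff as 1 / math.sqrt(float(query.axis_size(Key))). A minimal, self-contained JAX sketch of where such a factor enters plain (unfused) attention; the sizes and array names below are illustrative and are not taken from the Levanter code:

import math

import jax
import jax.numpy as jnp

# Illustrative sizes, not taken from the commit.
batch, num_heads, seq_len, head_dim = 2, 4, 8, 16

q = jnp.ones((batch, num_heads, seq_len, head_dim))
k = jnp.ones((batch, num_heads, seq_len, head_dim))
v = jnp.ones((batch, num_heads, seq_len, head_dim))

# Same form as the commit's new scaling_factor: 1 / sqrt(head_dim).
scaling_factor = 1 / math.sqrt(float(head_dim))

# Scaled dot-product attention: scores have shape (batch, heads, query_pos, key_pos).
scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scaling_factor
weights = jax.nn.softmax(scores, axis=-1)
attn_out = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
print(attn_out.shape)  # (2, 4, 8, 16)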
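The comment kept in the diff says TransformerEngine's self_fused_attn expects a packed qkv tensor of shape (batch, max_seqlen, nqkv, num_heads, head_dim) with nqkv = 3. A small sketch of how the jnp.stack call in the diff produces that layout; the shapes are illustrative and the TransformerEngine call itself is not reproduced here:

import jax.numpy as jnp

# Illustrative sizes, not taken from the commit.
batch, seq_len, num_heads, head_dim = 2, 8, 4, 16

# Layout of each tensor after the haliax.rearrange calls in the diff:
# (batch, position, heads, head_size).
q_ = jnp.zeros((batch, seq_len, num_heads, head_dim))
k_ = jnp.zeros((batch, seq_len, num_heads, head_dim))
v_ = jnp.zeros((batch, seq_len, num_heads, head_dim))

# Stacking on axis 2 packs q/k/v into the layout described in the comment:
# (batch, max_seqlen, nqkv, num_heads, head_dim) with nqkv = 3.
qkv = jnp.stack((q_, k_, v_), axis=2)
print(qkv.shape)  # (2, 8, 3, 4, 16)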

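The mask passed alongside AttnMaskType.CAUSAL_MASK is built as jnp.tril(jnp.ones((batch_size, seq_len, seq_len))), a lower-triangular matrix of ones per batch element. A tiny sketch with illustrative sizes showing what that mask looks like:

import jax.numpy as jnp

# Illustrative sizes, not taken from the commit.
batch_size, seq_len = 1, 4

# Lower-triangular ones: position i may attend to positions 0..i (causal).
mask = jnp.tril(jnp.ones((batch_size, seq_len, seq_len)))
print(mask[0])
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]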