
Commit 01f1792
ok bf16 works for attention now
dlwh committed Dec 19, 2024
1 parent 95f793e commit 01f1792
Showing 1 changed file with 3 additions and 3 deletions.
src/levanter/models/attention.py (3 additions, 3 deletions)

@@ -806,10 +806,10 @@ def _tpu_splash_attention(
     if bias is not None:
         raise NotImplementedError("Splash attention does not support bias")
 
-    if attention_dtype is not None and attention_dtype != jnp.float32:
-        warnings.warn("Splash attention only supports float32. Switching to float32.")
+    # if attention_dtype is not None and attention_dtype != jnp.float32:
+    #     warnings.warn("Splash attention only supports float32. Switching to float32.")
 
-    attention_dtype = jnp.float32
+    # attention_dtype = jnp.float32
 
     q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)
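
The effect of the change, in brief: before this commit, _tpu_splash_attention warned on any non-float32 attention_dtype and coerced it to jnp.float32; after it, the requested dtype (e.g. jnp.bfloat16) is passed through unchanged. The sketch below illustrates that before/after behavior only; _resolve_attention_dtype and its old_behavior flag are hypothetical names for illustration, not part of levanter's API.

import warnings

import jax.numpy as jnp


def _resolve_attention_dtype(attention_dtype, old_behavior: bool = False):
    # Hypothetical helper sketching this commit's change. With old_behavior=True
    # (the pre-commit code path), any non-float32 request triggers a warning and
    # is coerced to float32; with the new behavior, the requested dtype is
    # returned as-is, so bf16 splash attention is possible.
    if old_behavior:
        if attention_dtype is not None and attention_dtype != jnp.float32:
            warnings.warn("Splash attention only supports float32. Switching to float32.")
        return jnp.float32
    return attention_dtype


# Pre-commit path: a bf16 request was silently downgraded to float32.
assert _resolve_attention_dtype(jnp.bfloat16, old_behavior=True) == jnp.float32
# Post-commit path: the bf16 request survives.
assert _resolve_attention_dtype(jnp.bfloat16) == jnp.bfloat16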
