From 01f1792f652632ac567b9c316f1f4e0dbb9ac172 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Thu, 19 Dec 2024 15:40:08 -0800
Subject: [PATCH] ok bf16 works for attention now

---
 src/levanter/models/attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index 7044200ba..b054f1972 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -806,10 +806,10 @@ def _tpu_splash_attention(
     if bias is not None:
         raise NotImplementedError("Splash attention does not support bias")

-    if attention_dtype is not None and attention_dtype != jnp.float32:
-        warnings.warn("Splash attention only supports float32. Switching to float32.")
+    # if attention_dtype is not None and attention_dtype != jnp.float32:
+    #     warnings.warn("Splash attention only supports float32. Switching to float32.")

-    attention_dtype = jnp.float32
+    # attention_dtype = jnp.float32

     q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)
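
With the float32 coercion above commented out, a caller-supplied attention_dtype
(e.g. jnp.bfloat16) is no longer overridden before the splash-attention kernel
runs, so bf16 inputs can stay bf16 end to end. Below is a minimal sketch of the
kind of downstream dtype handling this enables; the helper _cast_for_attention
is hypothetical and not part of this patch or the Levanter codebase.

    import jax.numpy as jnp

    def _cast_for_attention(q, k, v, attention_dtype=None):
        # Hypothetical helper, for illustration only: cast q/k/v when a
        # dtype is explicitly requested, otherwise leave them untouched so
        # bfloat16 inputs reach the attention kernel without an upcast.
        if attention_dtype is not None:
            q = q.astype(attention_dtype)
            k = k.astype(attention_dtype)
            v = v.astype(attention_dtype)
        return q, k, v

    # Example: keep everything in bfloat16 on TPU.
    q = jnp.zeros((8, 128, 64), dtype=jnp.bfloat16)
    k = jnp.zeros((8, 128, 64), dtype=jnp.bfloat16)
    v = jnp.zeros((8, 128, 64), dtype=jnp.bfloat16)
    q, k, v = _cast_for_attention(q, k, v)  # dtypes remain bfloat16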