diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index 7289abed2..ff6ba8876 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -132,7 +132,7 @@ def dot_product_attention(
         print(f"scaling factor type: {type(scaling_factor)}")
         print("\n\n")
 
-        return self_fused_attn(
+        attn_output = self_fused_attn(
             qkv=qkv,  # jnp.ndarray,
             bias=fused_attn_bias,  # jnp.ndarray,
             mask=mask,  # jnp.ndarray,
@@ -144,6 +144,9 @@ def dot_product_attention(
             is_training=is_training,  # bool,
         )
 
+        attn_output = haliax.named(attn_output, ("batch", "position", "heads", "head_size"))
+        return attn_output
+
     else:
         QPos = query.resolve_axis(QPos)
         KPos = key.resolve_axis(KPos)
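
For context on what the new `haliax.named` call buys us: the fused kernel returns a plain `jnp.ndarray`, and wrapping it restores the named axes that the rest of Levanter's code expects. Below is a minimal sketch, not part of the patch; the shapes are illustrative stand-ins for the real fused-attention output.

```python
# Minimal sketch (hypothetical shapes): wrapping a raw jnp.ndarray, as returned
# by self_fused_attn, into a haliax NamedArray with named axes.
import jax.numpy as jnp
import haliax

raw = jnp.zeros((2, 16, 4, 8))  # stand-in for the raw fused-attention output
named = haliax.named(raw, ("batch", "position", "heads", "head_size"))

print(named.axes)  # a tuple of Axis objects, e.g. Axis(name='batch', size=2), ...

# Downstream code can now address axes by name rather than positional index:
pooled = haliax.sum(named, "heads")
```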