Skip to content

Commit

Permalink
Need output to be a named array
Browse files Browse the repository at this point in the history
  • Loading branch information
Virginia Adams committed Feb 6, 2024
1 parent 8e47717 commit cf8e873
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/levanter/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def dot_product_attention(
print(f"scaling factor type: {type(scaling_factor)}")
print("\n\n")

return self_fused_attn(
attn_output = self_fused_attn(
qkv=qkv, # jnp.ndarray,
bias=fused_attn_bias, # jnp.ndarray,
mask=mask, # jnp.ndarray,
Expand All @@ -144,6 +144,9 @@ def dot_product_attention(
is_training=is_training, # bool,
)

attn_output = haliax.named(attn_output, ("batch", "position", "heads", "head_size"))
return attn_output

else:
QPos = query.resolve_axis(QPos)
KPos = key.resolve_axis(KPos)
Expand Down

0 comments on commit cf8e873

Please sign in to comment.