Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Feb 16, 2024
1 parent 446fb62 commit 4136e66
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/levanter/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def true_fun(_):
# not into ff for now
# Operations if layer_idx equals 4
# sum over sequence length
return hax.sin(0.1*hax.sum(prev_x, axis='position')) + attn_output + ff_output
return hax.sin(*hax.sum(prev_x, axis='position')) + attn_output + ff_output
def false_fun(_):
# Otherwise return the same tensor
# as expected
Expand Down

0 comments on commit 4136e66

Please sign in to comment.