diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py
index b3e8c08d5..5b97ea208 100644
--- a/src/levanter/models/gpt2.py
+++ b/src/levanter/models/gpt2.py
@@ -281,7 +281,7 @@ def true_fun(_):
             # not into ff for now
             # Operations if layer_idx equals 4
             # sum over sequence length
-            return hax.sin(0.1*hax.sum(prev_x, axis='position')) + attn_output + ff_output
+            return hax.sin(hax.sum(prev_x, axis='position')) + attn_output + ff_output
         def false_fun(_):
             # Otherwise return the same tensor
             # as expected