Commit

fix
ahmeda14960 committed Feb 16, 2024
1 parent eea0fa0 commit 09a1f9f
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/levanter/models/gpt2.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from functools import partial
 from typing import Callable, Dict, Optional, Type
+import levanter

 import equinox as eqx
 from jax import lax
@@ -281,18 +282,25 @@ def true_fun(_):
     # not into ff for now
     # Operations if layer_idx equals 4
     # sum over sequence length
-    return hax.sin(*hax.sum(prev_x, axis='position')) + attn_output + ff_output
+    mode = "integrated"
+    if mode == "integrated":
+        return hax.sin(hax.sum(prev_x, axis='position')) + attn_output + ff_output
+    elif mode == "mod":
+        return hax.sin(0.1*hax.sum(x, axis='position')) + ff_output
+    else:
+        return x + ff_output
 def false_fun(_):
     # Otherwise return the same tensor
     # as expected
     return x + ff_output

 # if layer is one of the last three (12 layers) add sine target
-sin_target = lax.cond(jnp.greater_equal(layer_idx.array, 20), true_fun, false_fun, None)
+sin_target = lax.cond(jnp.greater_equal(layer_idx.array, 12), true_fun, false_fun, None)

 # jax.lax.stopgradient

-x = x + ff_output
+#x = x + hax.sin(hax.sum(ff_output, axis='position')) + ff_output
+x = x + ff_output# sin_target
 activation_diff = hax.square(x - sin_target)
 return x, activation_diff
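For readers unfamiliar with the pattern in this hunk: lax.cond gates the sine target on the layer index so that only later layers produce it, and the squared difference against that target is returned alongside the activations. Below is a minimal, self-contained sketch of the same pattern in plain JAX (the commit itself operates on Levanter/haliax named arrays; the function name sine_target_for_layer, the toy shapes, and LAYER_THRESHOLD are illustrative assumptions, not code from the repository).

import jax
import jax.numpy as jnp

LAYER_THRESHOLD = 12  # assumed: layers at or above this index get the sine target

def sine_target_for_layer(x, attn_output, ff_output, layer_idx):
    def true_fun(_):
        # sum over the sequence (position) axis, take the sine, then add the
        # per-position attention and feed-forward outputs via broadcasting
        return jnp.sin(jnp.sum(x, axis=0)) + attn_output + ff_output

    def false_fun(_):
        # earlier layers: the target is just the ordinary residual update
        return x + ff_output

    # both branches must return arrays of identical shape and dtype
    return jax.lax.cond(layer_idx >= LAYER_THRESHOLD, true_fun, false_fun, None)

# toy usage with (position=8, embed=16) arrays
x = jnp.ones((8, 16))
attn = jnp.zeros((8, 16))
ff = jnp.zeros((8, 16))
sin_target = sine_target_for_layer(x, attn, ff, jnp.int32(13))
new_x = x + ff
activation_diff = jnp.square(new_x - sin_target)  # elementwise squared error against the target

Using lax.cond rather than a Python if keeps the branch selection traceable under jit, which matters when layer_idx is itself a traced value (as the use of layer_idx.array in the diff suggests it is here).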

@@ -440,6 +448,14 @@ def compute_loss(
     targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
     target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)

+
+    if key is None:
+        return cross_entropy_loss(
+            logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
+        )
+    return hax.mean(sine_output), cross_entropy_loss(
+        logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
+    )
     if key is None:
         return cross_entropy_loss(
             logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
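With this change, compute_loss has two return shapes: a single cross-entropy scalar when key is None, and a (mean sine loss, cross-entropy) pair otherwise. sine_output is assumed to be computed earlier in the method (it is not shown in this hunk), and because the inserted block returns unconditionally, the pre-existing if key is None branch below it appears to be unreachable. The commit does not show the caller, so the following is only one assumed way a training step could fold the pair back into a scalar objective; combine_losses and AUX_WEIGHT are hypothetical names, not part of Levanter.

AUX_WEIGHT = 0.1  # hypothetical weight on the auxiliary sine-matching term

def combine_losses(loss_or_pair):
    # accept either the plain cross-entropy scalar or the (sine_loss, ce_loss) pair
    if isinstance(loss_or_pair, tuple):
        sine_loss, ce_loss = loss_or_pair
        return ce_loss + AUX_WEIGHT * sine_loss
    return loss_or_pair

# usage
print(combine_losses(2.0))           # plain cross-entropy path -> 2.0
print(combine_losses((0.5, 2.0)))    # auxiliary path -> 2.0 + 0.1 * 0.5 = 2.05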
