From 09a1f9f9f115a4ff4851a1a50efa45f6190cd35b Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Fri, 16 Feb 2024 10:31:28 -0800 Subject: [PATCH] fix --- src/levanter/models/gpt2.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index 5b97ea208..080b19ee8 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from functools import partial from typing import Callable, Dict, Optional, Type +import levanter import equinox as eqx from jax import lax @@ -281,18 +282,25 @@ def true_fun(_): # not into ff for now # Operations if layer_idx equals 4 # sum over sequence length - return hax.sin(*hax.sum(prev_x, axis='position')) + attn_output + ff_output + mode = "integrated" + if mode == "integrated": + return hax.sin(hax.sum(prev_x, axis='position')) + attn_output + ff_output + elif mode == "mod": + return hax.sin(0.1*hax.sum(x, axis='position')) + ff_output + else: + return x + ff_output def false_fun(_): # Otherwise return the same tensor # as expected return x + ff_output # if layer is one of the last three (12 layers) add sine target - sin_target = lax.cond(jnp.greater_equal(layer_idx.array, 20), true_fun, false_fun, None) + sin_target = lax.cond(jnp.greater_equal(layer_idx.array, 12), true_fun, false_fun, None) # jax.lax.stopgradient - x = x + ff_output + #x = x + hax.sin(hax.sum(ff_output, axis='position')) + ff_output + x = x + ff_output# sin_target activation_diff = hax.square(x - sin_target) return x, activation_diff @@ -440,6 +448,14 @@ def compute_loss( targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) + + if key is None: + return cross_entropy_loss( + logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask + ) + return hax.mean(sine_output), cross_entropy_loss( + logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask + ) if key is None: return cross_entropy_loss( logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask