Skip to content

Commit

Permalink
Switch to sum -> divide
Browse files Browse the repository at this point in the history
  • Loading branch information
Aphoh committed Dec 18, 2024
1 parent fcf15ab commit 3ae3014
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 76 deletions.
4 changes: 2 additions & 2 deletions config/backpack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ model:
trainer:
tracker:
project: "levanter"
tags: [ "openwebtext", "backpack" ]
tags: ["openwebtext", "backpack"]

mp: p=f32,c=bfloat16

Expand All @@ -21,5 +21,5 @@ trainer:
model_axis_size: 1

optimizer:
learning_rate: 6E-4
learning_rate: 6e-4
weight_decay: 0.1
40 changes: 21 additions & 19 deletions src/levanter/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tqdm_loggable import tqdm_logging
from tqdm_loggable.auto import tqdm

import haliax as hax
import haliax.nn
from haliax import NamedArray, is_named_array
from haliax.jax_utils import is_jax_array_like
Expand All @@ -30,6 +31,8 @@
from levanter.utils import flop_utils, jax_utils
from levanter.utils.jax_utils import barrier_sync, jnp_to_python
from levanter.utils.logging import save_xla_dumps_to_wandb
from levanter.utils.stat_utils import RunningMean
from levanter.utils.types import Extras
from levanter.visualization import compute_and_visualize_log_probs as viz_probs


Expand Down Expand Up @@ -145,10 +148,8 @@ async def compute_length():


def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
total_loss = 0.0
total_load_time = 0.0
total_loss_time = 0.0
n = 0
loss = RunningMean(jnp.zeros(()), jnp.zeros(()))
extras: Extras = {}

if name is not None:
desc = f"eval {name}"
Expand All @@ -159,28 +160,27 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n
pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches)

iter_ = iter(pbar)
n = 0
while True:
time_in = time.time()
n += 1
batch = next(iter_, None)
if batch is None:
break
load_time = time.time() - time_in
total_load_time += load_time
loss = loss_fn(model, batch)
total_loss += loss.item()
n += 1
loss_time = time.time() - time_in - load_time
total_loss_time += loss_time
# Bind the per-batch extras to their own name. Unpacking straight into
# `extras` would replace the running accumulator on every batch, so the
# merge loop below would only ever combine a batch's values with themselves.
losses, where, batch_extras = loss_fn(model, batch)
mean_loss = hax.mean(losses, where=where)
# Fold this batch's masked mean into the running mean; where.sum() is
# passed as the total that accompanies it.
loss += RunningMean(mean_loss, where.sum())
for k, v in batch_extras.items():
    if k in extras:
        extras[k] = extras[k] + v
    else:
        extras[k] = v

pbar.set_postfix(loss=total_loss / n)
pbar.set_postfix(loss=loss.mean.item())

if max_batches is not None and n >= max_batches:
break

if n > 0:
total_loss /= n

return total_loss
return loss.item(), {k: v.item() for k, v in extras.items()}


def compute_validation_loss(
Expand All @@ -190,12 +190,14 @@ def compute_validation_loss(
name: Optional[str] = None,
):
def compute_loss(info: StepInfo):
loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name)
loss, extras = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name)

prefix = "eval"
if name:
prefix += "/" + name
levanter.tracker.log({f"{prefix}/loss": loss}, step=info.step)
levanter.tracker.log(
{f"{prefix}/loss": loss} | {f"{prefix}/{k}": v for k, v in extras.items()}, step=info.step
)

if name:
logger.info(f"{name} validation loss: {loss:.3f}")
Expand Down
9 changes: 7 additions & 2 deletions src/levanter/doremi.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,17 @@ def doremi_step(state: DoremiState, ref, batch, domains):
proxy = inference_mode(state.model, False)
with hax.axis_mapping(trainer.compute_axis_mapping):
# calculate per-token losses for proxy and ref
proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: loss_fn(p, batch, reduction_axis=()), proxy)
ref_losses = loss_fn(ref, batch, reduction_axis=())
def scalar_loss_fn(p, batch):
    """Call ``loss_fn(p, batch)`` and return only the first of its three results.

    The remaining two values are discarded so the function can be handed to
    ``eqx.filter_vjp`` and compared directly against the reference model's
    output.
    """
    # NOTE(review): other call sites unpack this triple as
    # (losses, where, extras); the names below assume the same order.
    losses, _where, _extras = loss_fn(p, batch)
    return losses

proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: scalar_loss_fn(p, batch), proxy)
ref_losses = scalar_loss_fn(ref, batch)

# calculate excess losses, aggregate per-domain losses
excess_losses = proxy_losses - ref_losses
# Clamp at zero so only positions where the proxy loss exceeds the
# reference loss contribute to the per-domain aggregate.
clipped_losses = hax.maximum(excess_losses, 0)
per_domain_losses = _compute_per_domain_losses(clipped_losses, Domain, domains)

# Update domain weights
Expand Down
47 changes: 11 additions & 36 deletions src/levanter/grad_accum.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import enum
import functools
from typing import Callable, Optional, ParamSpec, TypeVar

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree as jtu
from jax.lax import with_sharding_constraint
from jax.sharding import PartitionSpec

Expand All @@ -25,34 +23,6 @@
X = TypeVar("X", contravariant=True) # Input


class ReductionType(enum.Enum):
SUM = enum.auto()
MEAN = enum.auto()
# TODO: add MAX?


def apply_updates_running(acc, r, updates, overwrites):
def _running_sum_updates(u, p):
if u is None:
return p
else:
return p * (1 - r) + u * r

def _is_none(x):
return x is None

def _apply_update(tree, update, overwrite):
if overwrite is not None:
return overwrite

return jtu.map(_running_sum_updates, update, tree, is_leaf=_is_none)

def is_leaf(x):
return x is None or isinstance(x, hq.OverwriteWithGradient)

return jtu.map(_apply_update, acc, updates, overwrites, is_leaf=is_leaf)


# TODO: should we use a custom_jvp on microbatched?

# cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617
Expand Down Expand Up @@ -108,6 +78,8 @@ def microbatched(
@functools.wraps(loss_fn)
def no_accum_loss_fn(*args, **kwargs):
    """Reduce ``loss_fn``'s output to a scalar loss plus an extras dict.

    Returns:
        A pair ``(loss, extras)``: ``loss`` is the mean of ``losses`` taken
        where ``where`` is set, and ``extras`` is ``loss_fn``'s extras with
        a ``"seen_tokens"`` entry holding ``where.sum()``.
    """
    losses, where, extras = loss_fn(*args, **kwargs)
    seen_tokens = where.sum().scalar()
    # Build a new dict rather than writing into the one loss_fn returned, so
    # a loss function that hands back a shared dict is not modified here.
    extras = {**extras, "seen_tokens": seen_tokens}
    return hax.mean(losses, where=where).scalar(), extras

return eqx.filter_value_and_grad(no_accum_loss_fn, has_aux=True)
Expand All @@ -119,7 +91,7 @@ def no_accum_loss_fn(*args, **kwargs):
@functools.wraps(loss_fn)
def accum_loss_fn(*args, **kwargs):
losses, where, extras = loss_fn(*args, **kwargs)
return hax.mean(losses, where=where).scalar(), (where.sum(), extras)
return hax.sum(losses, where=where).scalar(), (where.sum(), extras)

grad_fn = eqx.filter_value_and_grad(accum_loss_fn, has_aux=True)

Expand Down Expand Up @@ -154,17 +126,20 @@ def loop(acc, microbatch_and_key):

# TODO: this uses the latest value for the scale for fp8, which seems not ideal but probably ok?
overwrites, updates = hq.partition_for_grad_overwrite(grads_mb)
r = n_mb / (total + n_mb)
loss = loss + (loss_mb - loss) * r
grads = apply_updates_running(grads, r, updates, overwrites)
grads = hq.apply_updates(grads, updates, overwrites)
grads = hax.shard_with_axis_mapping(grads, accum_axis_mapping)
print(loss, loss_mb, r)
loss += loss_mb
total += n_mb

return (loss, (total, {k: v + extras_mb[k] for k, v in extras.items()})), grads

with jax.named_scope("microbatched"):
(loss, (_, extras)), grads, = hax.fold(
(loss, (total, extras)), grads, = hax.fold(
loop, AccumStep
)(acc, (args, kwargs, key))
grads = jax.tree_util.tree_map(lambda x: x / total, grads)
loss /= total
extras["seen_tokens"] = total

return (loss, extras), grads

Expand Down
23 changes: 14 additions & 9 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from levanter.utils import cloud_utils, fsspec_utils
from levanter.utils.jax_utils import create_fsdp_mesh, zeros_like_tree
from levanter.utils.tree_utils import inference_mode
from levanter.utils.types import ComputeLossFunction, FilterSpec
from levanter.utils.types import ComputeLossFunction, Extras, FilterSpec


logger = pylogging.getLogger(__name__)
Expand Down Expand Up @@ -391,10 +391,10 @@ def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]:

with capture_time() as step_time:
if hooks_this_time:
loss, new_state, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs)
loss, new_state, extras, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs)
# force the loss so timing numbers are accurate. laziness isn't going to help here (i think?)
else:
loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
loss, new_state, extras = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
loss = loss.item() # type: ignore

info = StepInfo(new_state, loss, step_time())
Expand All @@ -404,7 +404,8 @@ def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]:
if hooks_this_time:
self.hooks.run_jit_hooks_outside_step(info, cb_states)

levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step)
log_items = {k: v.item() for k, v in extras.items()} | {"throughput/hook_time": hook_time()}
levanter.tracker.log(log_items, step=info.step)

return info

Expand Down Expand Up @@ -525,11 +526,13 @@ def _jit_train_step_fn_no_hook(self):

def _train_step(
self, state: S, batch, batch_kwargs, _no_hooks=False
) -> tuple[Scalar, S, Sequence[CBInfo]] | tuple[Scalar, S]:
) -> tuple[Scalar, S, Extras, Sequence[CBInfo]] | tuple[Scalar, S, Extras]:
key, new_key = jax.random.split(state.training_key)
model = inference_mode(state.model, False)

loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key)
(loss, extras), grads = self._compute_gradients_microbatched(
self.loss_fn, model, batch, **batch_kwargs, key=key
)

with hax.axis_mapping(self.parameter_axis_mapping):
if not _no_hooks:
Expand All @@ -545,11 +548,13 @@ def obj_fun(trainable_model):
new_state = state.take_step(grads, obj_fun=obj_fun)
new_state = hax.shard(new_state, self.parameter_axis_mapping)
if _no_hooks:
return loss, new_state
return loss, new_state, extras
else:
return loss, new_state, hook_infos
return loss, new_state, extras, hook_infos

def _compute_gradients_microbatched(self, loss_fn, model: M, batch: X, **batch_kwargs) -> tuple[Scalar, M]:
def _compute_gradients_microbatched(
self, loss_fn, model: M, batch: X, **batch_kwargs
) -> tuple[tuple[Scalar, Extras], M]:
mbs = self.config.microbatch_size
grad_fn = microbatched(
loss_fn,
Expand Down
20 changes: 18 additions & 2 deletions src/levanter/utils/stat_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
import typing
from typing import TypeAlias

import equinox as eqx
import jax.numpy as jnp
import numpy as np
from typing_extensions import Self

import haliax as hax

from levanter.utils.types import Accumulatable

Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray

Arrayish: TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray


class SumScalar(Accumulatable):
    """Accumulatable wrapper around an array; instances combine by addition."""

    # The wrapped array.
    value: jnp.ndarray

    def item(self) -> float:
        """Return the wrapped value as a Python float."""
        # `.item()` yields an int for integer arrays; coerce so the result
        # always matches the annotated return type.
        return float(self.value.item())

    def __add__(self, other: Self) -> Self:
        """Return a new ``SumScalar`` holding ``self.value + other.value``."""
        return SumScalar(self.value + other.value)


class RunningMean(eqx.Module):
Expand All @@ -27,6 +40,9 @@ def add(self, x: Arrayish, total: Arrayish) -> "RunningMean":
new_total = self.total + total
return RunningMean(new_mean, new_total)

def item(self) -> float:
    """Return the current mean as a Python float."""
    # `.item()` yields an int when the mean has an integer dtype; coerce so
    # the result always matches the annotated return type.
    return float(self.mean.item())

def __add__(self, other: "RunningMean"):
return self.add(other.mean, other.total)

Expand Down
21 changes: 19 additions & 2 deletions src/levanter/utils/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Any, Callable, Protocol, Tuple, TypeVar, Union
import abc
from typing import Any, Callable, Dict, Protocol, Tuple, TypeAlias, TypeVar, Union

import equinox as eqx
import jax
from jaxtyping import PyTree
from typing_extensions import Self

import haliax as hax
from haliax.types import Scalar
Expand All @@ -10,6 +14,19 @@
M_con = TypeVar("M_con", contravariant=True) # Model
X = TypeVar("X", contravariant=True) # Input


class Accumulatable(abc.ABC, eqx.Module):
    """Abstract interface for values that can be added together and read out as a float."""

    @abc.abstractmethod
    def item(self) -> float:
        """Return this value as a Python float."""

    @abc.abstractmethod
    def __add__(self, other: Self) -> Self:
        """Return the result of combining ``self`` with ``other``."""


Extras: TypeAlias = Dict[str, jax.Array | Accumulatable]

try:
from haliax.nn.scan import BlockFoldable
except ImportError:
Expand Down Expand Up @@ -53,5 +70,5 @@ def __call__(
model: M_con,
input: X,
**kwargs,
) -> tuple[hax.NamedArray, hax.NamedArray, dict]:
) -> tuple[hax.NamedArray, hax.NamedArray, Extras]:
...
5 changes: 3 additions & 2 deletions tests/test_doremi.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ def test_estimate_mixture_weights():
ds3 = LogitDataset(W3, 0.05, x3_mask, x3_bias, key=next(keys))

# TODO: remove key as a requirement for models
def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key=None):
def compute_loss_fn(model, example, key=None):
del key
y_pred = model(example.x)
return hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=reduction, reduction_axis=reduction_axis)
losses = hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=None)
return losses, hax.ones_like(losses), {}

tiny_trainer_config = TrainerConfig(
num_train_steps=300,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_grad_accum.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def scalar_loss_fn(mlp, x):

mesh = Mesh(jax.devices(), ("data",))

# @hax.partitioning.named_jit(axis_resources=axis_mapping)
@hax.partitioning.named_jit(axis_resources=axis_mapping)
def jit_grad_accum(mlp, x):
grad_fn = microbatched(loss_fn, Batch, parallelism, axis_mapping, axis_mapping)
return grad_fn(mlp, x)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_hf_gpt2_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def torch_loss(model, input_ids) -> torch.Tensor:

def compute_loss(model: LmHeadModel, input_ids):
example = LmExample.causal(input_ids, eos_id=converter.tokenizer.eos_token_id)
return compute_next_token_loss(model, example, key=None).scalar()
loss, where, _ = compute_next_token_loss(model, example, key=None)
return hax.mean(loss, where=where).scalar()

jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False)
jax_grad: Gpt2LMHeadModel
Expand Down

0 comments on commit 3ae3014

Please sign in to comment.