Misc fixes from sweep (disable blocked CE by default) (#798)
Blocked CE is worse, except that it enables larger batch sizes.
dlwh authored Nov 12, 2024
1 parent 2195263 commit 433490e
Showing 5 changed files with 25 additions and 17 deletions.
10 changes: 6 additions & 4 deletions src/levanter/doremi.py
@@ -1,6 +1,6 @@
 import dataclasses
 import logging
-from typing import Optional, Tuple, TypeVar
+from typing import Mapping, Optional, Tuple, TypeVar
 
 import equinox as eqx
 import jax.numpy as jnp
@@ -56,7 +56,7 @@ def estimate_mixture_weights(
     loss_fn: ComputeLossFunction[M, T],
     initial_proxy: M,
     ref: M,
-    data_sources: dict[str, AsyncDataset[T]],
+    data_sources: Mapping[str, AsyncDataset[T]],
     sampling_weights: Optional[dict[str, float]] = None,
     *,
     validation_sets: Optional[dict[str, AsyncDataset[T]]] = None,
@@ -184,7 +184,9 @@ def doremi_step(state: DoremiState, ref, batch, domains):
 
     # we're not actually going to use the trainer for very much but it holds hooks and sets up contexts
     with trainer:
-        tagged_mixture = domain_tagged_mixture(data_sources, sampling_weights, domain_to_index, key=data_key)
+        tagged_mixture: MixtureDataset = domain_tagged_mixture(
+            data_sources, sampling_weights, domain_to_index, key=data_key
+        )
         state = load_checkpoint_or_initialize(
             DoremiState.init,
             trainer.checkpoint_path,
@@ -263,7 +265,7 @@ def _prepare_ref_model(ref, trainer):
 
 
 def domain_tagged_mixture(
-    data_sources: dict[str, AsyncDataset[T]],
+    data_sources: Mapping[str, AsyncDataset[T]],
     weights: dict[str, float],
     domain_to_index: dict[str, int],
     *,
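The dict → Mapping change in the two signatures above is likely about variance: dict[str, AsyncDataset[T]] is invariant, so a dict whose values are a subtype of AsyncDataset is rejected by type checkers, while Mapping is read-only and covariant in its value type and also signals that the function will not mutate the argument. A small illustration with generic containers, not Levanter types:

```python
from types import MappingProxyType
from typing import Mapping


def total_items(sources: Mapping[str, list]) -> int:
    # Mapping only promises read access, so the function cannot mutate the argument.
    return sum(len(v) for v in sources.values())


plain: dict[str, list] = {"web": [1, 2], "code": [3]}
frozen = MappingProxyType(plain)  # also acceptable here, unlike for a dict parameter

assert total_items(plain) == total_items(frozen) == 3
```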
4 changes: 2 additions & 2 deletions src/levanter/main/doremi_lm.py
@@ -109,7 +109,7 @@ def init_proxy_model():
     train_datasets = config.data.training_sets(ref_model.Pos.size)
     valid_datasets = config.data.validation_sets(ref_model.Pos.size)
 
-    train_datasets = {
+    causal_train_datasets = {
         k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id)
         for k, v in train_datasets.items()
     }
@@ -122,7 +122,7 @@ def init_proxy_model():
        loss_function,
        proxy_model,
        ref=ref_model,
-       data_sources=train_datasets,
+       data_sources=causal_train_datasets,
        trainer_config=config.trainer,
        optimizer=optimizer,
        domain_weight_step_size=config.doremi.domain_weight_step_size,
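The rename avoids rebinding train_datasets to its wrapped form, which would hide the raw datasets from later code in the function and silently change the name's type. A toy illustration of the pitfall, not Levanter types:

```python
raw_train = {"web": ["doc1", "doc2"], "code": ["doc3"]}

# Rebinding would discard the original mapping:
# raw_train = {k: iter(v) for k, v in raw_train.items()}

# Keeping a second name preserves both views:
wrapped_train = {k: iter(v) for k, v in raw_train.items()}

assert set(raw_train) == set(wrapped_train)  # the raw data is still reachable
```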
4 changes: 3 additions & 1 deletion src/levanter/models/lm_model.py
@@ -1,4 +1,5 @@
 import abc
+from dataclasses import dataclass
 from typing import Generic, Optional, Type, TypeVar
 
 import draccus
@@ -48,6 +49,7 @@ def causal(
 
 
 # TODO: for some reason, mypy doesn't like the discover_packages_path argument?
+@dataclass(frozen=True)
 class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"):  # type: ignore
     @property
     @abc.abstractmethod
@@ -69,7 +71,7 @@ def Pos(self) -> Axis:
     def Embed(self) -> Axis:
         pass
 
-    cross_entropy_block_size: Optional[int] = 64000
+    cross_entropy_block_size: Optional[int] = None
     """
     The block size for computing cross-entropy loss. This is the number of tokens that are processed together
    in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large
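Changing the default from 64000 to None means the full softmax path is used unless a block size is set explicitly, in line with the commit message: the blocked computation is worse (slower), but it keeps less of the logits live at once and therefore allows larger batches. Rough arithmetic with illustrative shapes, not defaults from the repo:

```python
batch, seq, vocab, block = 8, 4096, 128_000, 64_000
bytes_per_elem = 4  # float32 logits

full_logits_gb = batch * seq * vocab * bytes_per_elem / 1e9
blocked_live_gb = batch * seq * block * bytes_per_elem / 1e9

print(f"full softmax logits materialized at once: {full_logits_gb:.1f} GB")   # ~16.8 GB
print(f"blocked (block=64k) logits live at a time: {blocked_live_gb:.1f} GB")  # ~8.4 GB
```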
22 changes: 13 additions & 9 deletions src/levanter/models/loss.py
@@ -58,7 +58,9 @@ def next_token_loss(
 
     if block_size is None:
         # Full softmax computation
-        logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype)
+        logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed)
+        if dtype is not None:
+            logits = logits.astype(dtype)
         target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype)
         return cross_entropy_and_logsumexp_penalty(
             logits,
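This hunk (and the ones below) replace preferred_element_type on the dot with an explicit cast applied only when a dtype is requested. A sketch of the two patterns in plain jax.numpy rather than haliax; preferred_element_type on jnp.dot assumes a reasonably recent JAX:

```python
import jax.numpy as jnp

x = jnp.ones((4, 8), dtype=jnp.bfloat16)
w = jnp.ones((8, 16), dtype=jnp.bfloat16)
dtype = jnp.float32  # may also be None, in which case no upcast is wanted

# Old pattern: ask the matmul itself for the output element type.
logits_old = jnp.dot(x, w, preferred_element_type=dtype)

# New pattern: run the matmul in its natural dtype, then cast only if asked.
logits_new = jnp.dot(x, w)
if dtype is not None:
    logits_new = logits_new.astype(dtype)
```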
@@ -261,9 +263,10 @@ def process_block(block_idx, acc, current_block_size):
 
         # Materialize the logits for the current block
         lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)]  # [Contract, Block]
-        logits_b = hax.dot(
-            pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
-        )  # [Batch, Seq, Block]
+        logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract)  # [Batch, Seq, Block]
+
+        if dtype is not None:
+            logits_b = logits_b.astype(dtype)
 
         # Update max and logsumexp
         max_logit = hax.maximum(max_logit_prev, hax.max(logits_b, axis=Block))  # [Batch, Seq]
@@ -278,7 +281,7 @@ def process_block(block_idx, acc, current_block_size):
         # Update sumV. This is actually unnecessary if we're using one-hot targets
         # sV = sV_prev + hax.sum(target_y_b, axis=Label.name)
 
-        loss += hax.dot(logits_b, target_y_b, axis=Block, preferred_element_type=dtype)  # [Batch, Seq]
+        loss += hax.dot(logits_b, target_y_b, axis=Block)  # [Batch, Seq]
 
         return loss, logsumexp, max_logit  # , sV
 
@@ -351,7 +354,7 @@ def _block_cross_entropy_backward(
     num_blocks = vocab_size // block_size
 
     grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype)
-    grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_embeddings.dtype)
+    grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_lm_head.dtype)
 
     def process_block(block_idx, acc, current_block_size):
         """
@@ -372,14 +375,15 @@ def process_block(block_idx, acc, current_block_size):
 
         # Materialize the logits for the current block
        lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)]  # [Contract, Block]
-        logits_b = hax.dot(
-            pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
-        )  # [Batch, Seq, Block]
+        logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract)  # [Batch, Seq, Block]
 
         # Materialize the target for the current block (one-hot)
         target_y_block = _block_one_hot(Block, start, labels_y, logits_b.dtype)  # [Batch, Seq, Block]
 
         # materialize the softmax for the current block
+        if dtype is not None:
+            logits_b = logits_b.astype(dtype)
+
         p_b = hax.exp(logits_b - log_z)  # [Batch, Seq, Block]
 
         delta_b = p_b - target_y_block
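For reference, the idea behind the blocked path these hunks modify is a streaming (online) logsumexp over vocabulary blocks, so the full [batch, seq, vocab] logits are never materialized at once. A minimal self-contained sketch in plain jax.numpy, not Levanter's haliax implementation; names and shapes are illustrative:

```python
import jax
import jax.numpy as jnp


def blocked_cross_entropy(emb, lm_head, labels, block_size):
    """emb: [pos, d], lm_head: [d, vocab], labels: [pos] -> mean NLL."""
    pos = emb.shape[0]
    vocab = lm_head.shape[1]
    max_logit = jnp.full((pos,), -jnp.inf)
    logsumexp = jnp.zeros((pos,))
    label_logit = jnp.zeros((pos,))
    for start in range(0, vocab, block_size):
        block = lm_head[:, start : start + block_size]   # [d, block]
        logits_b = emb @ block                            # [pos, block]
        new_max = jnp.maximum(max_logit, logits_b.max(axis=-1))
        # rescale the running sum of exponentials to the new max, then add this block
        sum_exp = jnp.exp(logsumexp + max_logit - new_max) + jnp.exp(
            logits_b - new_max[:, None]
        ).sum(axis=-1)
        logsumexp, max_logit = jnp.log(sum_exp), new_max
        # pick out the label's logit when it falls inside this block
        idx = jnp.clip(labels - start, 0, logits_b.shape[-1] - 1)
        in_block = (labels >= start) & (labels < start + block_size)
        picked = jnp.take_along_axis(logits_b, idx[:, None], axis=-1)[:, 0]
        label_logit = label_logit + jnp.where(in_block, picked, 0.0)
    return ((logsumexp + max_logit) - label_logit).mean()


# Cross-check against the straightforward full-softmax loss.
emb = jnp.arange(6.0).reshape(3, 2) / 10
lm_head = jnp.arange(10.0).reshape(2, 5) / 10
labels = jnp.array([0, 3, 4])
logits = emb @ lm_head
full = (
    jax.nn.logsumexp(logits, axis=-1)
    - jnp.take_along_axis(logits, labels[:, None], axis=-1)[:, 0]
).mean()
assert jnp.allclose(blocked_cross_entropy(emb, lm_head, labels, block_size=2), full)
```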
2 changes: 1 addition & 1 deletion src/levanter/models/mpt.py
@@ -107,7 +107,7 @@ def from_hf(config: HfMptAttentionConfig):
 
 
 @LmConfig.register_subclass("mpt")
-@dataclass
+@dataclass(frozen=True)
 class MptConfig(HFCompatConfig):
     d_model: int = 768
     n_heads: int = 12
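Marking configs frozen (here and on LmConfig above) makes instances immutable and hashable. A hedged sketch of what that buys, using a toy config rather than the real MptConfig:

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class TinyConfig:
    d_model: int = 768
    n_heads: int = 12


cfg = TinyConfig()
try:
    cfg.d_model = 1024  # any mutation now raises
except FrozenInstanceError:
    pass

compiled_cache = {cfg: "compiled function"}  # hashable, so usable as a dict key
```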
