From 433490e3858b87b5c3332d2bff7feaf48bc6124f Mon Sep 17 00:00:00 2001
From: David Hall
Date: Tue, 12 Nov 2024 13:54:50 -0800
Subject: [PATCH] Misc fixes from sweep (disable blocked CE by default) (#798)

blocked CE is worse except it enables larger batch sizes
---
 src/levanter/doremi.py          | 10 ++++++----
 src/levanter/main/doremi_lm.py  |  4 ++--
 src/levanter/models/lm_model.py |  4 +++-
 src/levanter/models/loss.py     | 22 +++++++++++++---------
 src/levanter/models/mpt.py      |  2 +-
 5 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py
index 9d048b24f..6d9165cfc 100644
--- a/src/levanter/doremi.py
+++ b/src/levanter/doremi.py
@@ -1,6 +1,6 @@
 import dataclasses
 import logging
-from typing import Optional, Tuple, TypeVar
+from typing import Mapping, Optional, Tuple, TypeVar
 
 import equinox as eqx
 import jax.numpy as jnp
@@ -56,7 +56,7 @@ def estimate_mixture_weights(
     loss_fn: ComputeLossFunction[M, T],
     initial_proxy: M,
     ref: M,
-    data_sources: dict[str, AsyncDataset[T]],
+    data_sources: Mapping[str, AsyncDataset[T]],
     sampling_weights: Optional[dict[str, float]] = None,
     *,
     validation_sets: Optional[dict[str, AsyncDataset[T]]] = None,
@@ -184,7 +184,9 @@ def doremi_step(state: DoremiState, ref, batch, domains):
 
     # we're not actually going to use the trainer for very much but it holds hooks and sets up contexts
    with trainer:
-        tagged_mixture = domain_tagged_mixture(data_sources, sampling_weights, domain_to_index, key=data_key)
+        tagged_mixture: MixtureDataset = domain_tagged_mixture(
+            data_sources, sampling_weights, domain_to_index, key=data_key
+        )
         state = load_checkpoint_or_initialize(
             DoremiState.init,
             trainer.checkpoint_path,
@@ -263,7 +265,7 @@ def _prepare_ref_model(ref, trainer):
 
 
 def domain_tagged_mixture(
-    data_sources: dict[str, AsyncDataset[T]],
+    data_sources: Mapping[str, AsyncDataset[T]],
     weights: dict[str, float],
     domain_to_index: dict[str, int],
     *,
diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py
index 12b3e6ae0..742c3229c 100644
--- a/src/levanter/main/doremi_lm.py
+++ b/src/levanter/main/doremi_lm.py
@@ -109,7 +109,7 @@ def init_proxy_model():
     train_datasets = config.data.training_sets(ref_model.Pos.size)
     valid_datasets = config.data.validation_sets(ref_model.Pos.size)
 
-    train_datasets = {
+    causal_train_datasets = {
         k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id)
         for k, v in train_datasets.items()
     }
@@ -122,7 +122,7 @@ def init_proxy_model():
         loss_function,
         proxy_model,
         ref=ref_model,
-        data_sources=train_datasets,
+        data_sources=causal_train_datasets,
         trainer_config=config.trainer,
         optimizer=optimizer,
         domain_weight_step_size=config.doremi.domain_weight_step_size,
diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py
index 911e74b09..1a82aa7be 100644
--- a/src/levanter/models/lm_model.py
+++ b/src/levanter/models/lm_model.py
@@ -1,4 +1,5 @@
 import abc
+from dataclasses import dataclass
 from typing import Generic, Optional, Type, TypeVar
 
 import draccus
@@ -48,6 +49,7 @@ def causal(
 
 
 # TODO: for some reason, mypy doesn't like the discover_packages_path argument?
+@dataclass(frozen=True)
 class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"):  # type: ignore
     @property
     @abc.abstractmethod
@@ -69,7 +71,7 @@ def Pos(self) -> Axis:
     def Embed(self) -> Axis:
         pass
 
-    cross_entropy_block_size: Optional[int] = 64000
+    cross_entropy_block_size: Optional[int] = None
     """
     The block size for computing cross-entropy loss. This is the number of tokens that are processed together
     in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large
diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py
index 154fc66ac..d705eda4d 100644
--- a/src/levanter/models/loss.py
+++ b/src/levanter/models/loss.py
@@ -58,7 +58,9 @@ def next_token_loss(
 
     if block_size is None:
         # Full softmax computation
-        logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype)
+        logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed)
+        if dtype is not None:
+            logits = logits.astype(dtype)
         target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype)
         return cross_entropy_and_logsumexp_penalty(
             logits,
@@ -261,9 +263,10 @@ def process_block(block_idx, acc, current_block_size):
 
         # Materialize the logits for the current block
         lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)]  # [Contract, Block]
-        logits_b = hax.dot(
-            pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
-        )  # [Batch, Seq, Block]
+        logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract)  # [Batch, Seq, Block]
+
+        if dtype is not None:
+            logits_b = logits_b.astype(dtype)
 
         # Update max and logsumexp
         max_logit = hax.maximum(max_logit_prev, hax.max(logits_b, axis=Block))  # [Batch, Seq]
@@ -278,7 +281,7 @@ def process_block(block_idx, acc, current_block_size):
         # Update sumV. This is actually unnecessary if we're using one-hot targets
         # sV = sV_prev + hax.sum(target_y_b, axis=Label.name)
 
-        loss += hax.dot(logits_b, target_y_b, axis=Block, preferred_element_type=dtype)  # [Batch, Seq]
+        loss += hax.dot(logits_b, target_y_b, axis=Block)  # [Batch, Seq]
 
         return loss, logsumexp, max_logit  # , sV
 
@@ -351,7 +354,7 @@ def _block_cross_entropy_backward(
     num_blocks = vocab_size // block_size
 
     grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype)
-    grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_embeddings.dtype)
+    grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_lm_head.dtype)
 
     def process_block(block_idx, acc, current_block_size):
         """
@@ -372,14 +375,15 @@ def process_block(block_idx, acc, current_block_size):
 
         # Materialize the logits for the current block
         lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)]  # [Contract, Block]
-        logits_b = hax.dot(
-            pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
-        )  # [Batch, Seq, Block]
+        logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract)  # [Batch, Seq, Block]
 
         # Materialize the target for the current block (one-hot)
         target_y_block = _block_one_hot(Block, start, labels_y, logits_b.dtype)  # [Batch, Seq, Block]
 
         # materialize the softmax for the current block
+        if dtype is not None:
+            logits_b = logits_b.astype(dtype)
+
         p_b = hax.exp(logits_b - log_z)  # [Batch, Seq, Block]
 
         delta_b = p_b - target_y_block
diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py
index 0809d9d23..97b61f1dc 100644
--- a/src/levanter/models/mpt.py
+++ b/src/levanter/models/mpt.py
@@ -107,7 +107,7 @@ def from_hf(config: HfMptAttentionConfig):
 
 
 @LmConfig.register_subclass("mpt")
-@dataclass
+@dataclass(frozen=True)
 class MptConfig(HFCompatConfig):
     d_model: int = 768
     n_heads: int = 12
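
Background for the default change above: the blocked cross-entropy path controlled by cross_entropy_block_size walks the vocabulary in blocks and keeps a running max/logsumexp, so the full [Batch, Seq, Vocab] logit array is never materialized. That is what lets larger batch sizes fit, and the repeated passes over the LM head are why it is slower than the full-softmax branch taken when block_size is None. The sketch below illustrates the technique in plain jax.numpy; it is not Levanter's implementation (that lives in src/levanter/models/loss.py with Haliax named axes), and every name in it (blocked_cross_entropy, emb, lm_head) is invented for the example.

# Illustrative sketch only -- not Levanter's loss.py. Plain jax.numpy, no Haliax axes.
import jax
import jax.numpy as jnp


def blocked_cross_entropy(emb, lm_head, labels, block_size):
    """emb: [batch, embed], lm_head: [embed, vocab], labels: [batch] int class ids."""
    vocab = lm_head.shape[1]
    assert vocab % block_size == 0, "sketch assumes the block size divides the vocab"

    batch = emb.shape[0]
    running_max = jnp.full((batch,), -jnp.inf)
    running_sumexp = jnp.zeros((batch,))
    target_logit = jnp.zeros((batch,))

    for start in range(0, vocab, block_size):
        # Only a [batch, block_size] slice of the logits ever exists in memory.
        logits_b = emb @ lm_head[:, start : start + block_size]

        # Online logsumexp: rescale the old accumulator to the new running max.
        new_max = jnp.maximum(running_max, logits_b.max(axis=-1))
        running_sumexp = running_sumexp * jnp.exp(running_max - new_max)
        running_sumexp += jnp.exp(logits_b - new_max[:, None]).sum(axis=-1)
        running_max = new_max

        # Pick out the true label's logit if it falls inside this block.
        in_block = (labels >= start) & (labels < start + block_size)
        local = jnp.clip(labels - start, 0, block_size - 1)
        picked = jnp.take_along_axis(logits_b, local[:, None], axis=-1)[:, 0]
        target_logit += jnp.where(in_block, picked, 0.0)

    log_z = running_max + jnp.log(running_sumexp)  # logsumexp over the whole vocab
    return (log_z - target_logit).mean()


# Agrees with the ordinary full-softmax cross entropy (the default path after this patch):
emb = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
lm_head = jax.random.normal(jax.random.PRNGKey(1), (8, 32))
labels = jnp.array([3, 17, 0, 31])
logits = emb @ lm_head
full = (jax.nn.logsumexp(logits, axis=-1) - logits[jnp.arange(4), labels]).mean()
print(jnp.allclose(blocked_cross_entropy(emb, lm_head, labels, 8), full, atol=1e-5))

With the new default (cross_entropy_block_size: Optional[int] = None), next_token_loss takes the full-softmax branch; configs that need the memory savings can presumably still opt back in by setting the field to an integer block size, since the "if block_size is None" check itself is unchanged by this patch.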