diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py
index 4d71241d4..f448ed83b 100644
--- a/src/levanter/data/dataset.py
+++ b/src/levanter/data/dataset.py
@@ -2,7 +2,7 @@
 import asyncio
 import logging
 from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, Generic, Optional, Sequence, TypeVar
+from typing import Callable, Generic, Optional, Sequence, TypeAlias, TypeVar
 
 import jax.random
 import numpy as np
@@ -18,6 +18,11 @@
 T = TypeVar("T")
 U = TypeVar("U")
 
+# When we decide to standardize on 3.12, we can use fancier things
+# P = ParamSpec("P")
+
+MapFunction: TypeAlias = Callable[..., U]
+
 
 _executor = ThreadPoolExecutor(max_workers=10)
 
@@ -111,9 +116,12 @@ def as_sync_dataset(self):
     def as_async_dataset(self) -> "AsyncDataset[T_co]":
         return self
 
-    def map(self, fn: Callable[[T_co], U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[T_co, U]":
+    def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[T_co, U]":
         return MappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)
 
+    def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
+        return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)
+
     def shuffle(self, key: PRNGKey):
         import levanter.data.permutation as permutation
 
@@ -321,7 +329,7 @@ class MappedAsyncDataset(AsyncDataset[U], Generic[T, U]):
     def __init__(
         self,
         dataset: AsyncDataset[T],
-        fn: Callable[[T], U] | Callable[[T, Optional[PRNGKey]], U],
+        fn: MapFunction[U],
         *extra_args,
         **extra_kwargs,
     ):
@@ -365,3 +373,65 @@ def _call_fn(self, index, item):
         else:
             kwargs = self._extra_kwargs
         return self.fn(item, *self._extra_args, **kwargs)
+
+
+class BatchMappedAsyncDataset(AsyncDataset[U]):
+    """
+    A dataset that applies a function to each batch of items in the dataset.
+    You can pass extra arguments to the function using `*extra_args` and `**extra_kwargs`.
+    If a kwarg called `key` is passed, it is treated as a PRNGKey and folded in with the index of each item in
+    the batch, producing one key per item; the resulting keys are passed to the function as `key`.
+    """
+
+    def __init__(
+        self,
+        dataset: AsyncDataset[T],
+        fn: MapFunction[Sequence[U]],
+        *extra_args,
+        **extra_kwargs,
+    ):
+        super().__init__()
+        self.dataset = dataset
+        self.fn = fn
+        self._extra_args = extra_args
+        self._extra_kwargs = extra_kwargs
+
+    async def async_len(self) -> int:
+        return await self.dataset.async_len()
+
+    async def final_length_is_known(self) -> bool:
+        return await self.dataset.final_length_is_known()
+
+    def is_finite(self) -> bool:
+        return self.dataset.is_finite()
+
+    async def current_len(self) -> Optional[int]:
+        return await self.dataset.current_len()
+
+    def _maybe_fold_in_key(self, key, indices: Sequence[int]):
+        if key is not None:
+            key = _fold_in_key_vmap(key, np.array(indices))
+        return key
+
+    async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
+        items = await self.dataset.get_batch(indices)
+        return self._call_fn(indices, items)
+
+    async def getitem_async(self, index: int) -> U:
+        return self._call_fn([index], [await self.dataset.getitem_async(index)])[0]
+
+    async def wait_until_len_at_least(self, length: int) -> int:
+        return await self.dataset.wait_until_len_at_least(length)
+
+    def _call_fn(self, indices: Sequence[int], items):
+        if "key" in self._extra_kwargs:
+            key = self._maybe_fold_in_key(self._extra_kwargs["key"], indices)
+            kwargs = {**self._extra_kwargs, "key": key}
+        else:
+            kwargs = self._extra_kwargs
+        return self.fn(items, *self._extra_args, **kwargs)
+
+
+@jax.jit
+def _fold_in_key_vmap(key, indices):
+    return jax.vmap(lambda i: jax.random.fold_in(key, i))(indices)
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 3e74c96b7..053372207 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -779,9 +779,9 @@ def _preprocess_supervised_example(
     }
 
 
-def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> LmExample:
+def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> list[LmExample]:
     """
-    Prepare an example for training. This function converts the (cached) batch encoding into an LmExample.
+    Prepare examples for training. This function converts the (cached) encodings into LmExamples.
 
     It goes through the following steps:
 
@@ -789,26 +789,35 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po
     2. Mask out the input and prompt if requested.
     3. Create an LmExample with the input_ids as the input and the next token as the target.
     """
-    # annoyingly, pad expects things to be batched so we have to prepend a batch axis
-    ex = tokenizer.pad(
-        {k: np.expand_dims(v, 0) for k, v in ex.items()},
-        return_tensors="np",
+    lens = np.array([e["sources_len"] for e in ex])
+
+    ex_pad = tokenizer.pad(
+        ex,
+        padding="max_length",
         max_length=Pos.size,
     )
-    ex = {k: v[0] for k, v in ex.items()}
-    # padding doesn't do truncation, so we have to do it ourselves.
-    # Truncate from the left since we want to predict the last tokens
-    input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos)
-    # mask out padding and anything before the start of the target
-    loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1
+    input_ids = ex_pad["input_ids"]
+    truncated = [ids[-Pos.size :] for ids in input_ids]
+
+    out = []
+    for ids, source_len in zip(truncated, lens):
+        causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), source_len, tokenizer.pad_token_id)
+
+        out.append(causal)
+
+    return out
+
+
+@functools.partial(jax.jit, static_argnums=(0, 3))
+def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id):
+    # mask out padding and anything before the start of the target
+    loss_mask = hax.arange(Pos) >= sources_len - 1
 
     # don't predict the padding
     targets = hax.roll(input_ids, -1, Pos)
-    loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
+    loss_mask = loss_mask & (targets != pad_token_id)
     loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
-    lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
-
-    return lm_ex
+    return LmExample.causal(input_ids, loss_mask=loss_mask)
 
 
 def mk_supervised_datasets(
@@ -884,7 +893,7 @@ def mk_supervised_dataset(
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))
+    return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos))
 
 
 def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output_field):
@@ -899,7 +908,7 @@ def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output
         output_exemplar=output_exemplar,
     )
     cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(cache_dir, await_finished=True)
-    ds = cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))
+    ds = cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos))
     return ds
 
 
@@ -994,7 +1003,7 @@ def mk_chat_sft_dataset(
         tokenizer.pad_token = tokenizer.eos_token
 
     # Reuse the supervised prepare function directly
-    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))
+    return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos))
 
 
 @dataclass
diff --git a/tests/test_supervised.py b/tests/test_supervised.py
index 23f9e240c..40a3d927b 100644
--- a/tests/test_supervised.py
+++ b/tests/test_supervised.py
@@ -4,7 +4,7 @@
 import haliax
 from haliax import Axis
 
-from levanter.data.text import _prepare_supervised_example, _preprocess_supervised_example
+from levanter.data.text import _prepare_supervised_examples, _preprocess_supervised_example
 
 
 def test_supervised_eval():
@@ -77,7 +77,7 @@ def test_supervised_eval():
         "sources_len": np.array(45, dtype=np.int32),
     }
 
-    lm_ex = _prepare_supervised_example(ex, tokenizer, Axis("position", 128))
+    lm_ex = _prepare_supervised_examples([ex], tokenizer, Axis("position", 128))[0]
 
     assert lm_ex.loss_mask["position", 44]
     assert haliax.sum(lm_ex.loss_mask) == 1
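A minimal usage sketch of the new `map_batches` API with a `key` kwarg (not part of the patch itself; `train_ds` and `add_noise_batch` are hypothetical names). Per the `BatchMappedAsyncDataset` docstring, the `key` passed by the caller is folded in with each item's index via `_fold_in_key_vmap`, so the function receives one key per item in the batch:

    import jax
    import jax.numpy as jnp

    def add_noise_batch(items, *, key):
        # hypothetical batch function: `items` is one batch from the underlying dataset,
        # `key` is an array of PRNG keys, one per item, produced by _fold_in_key_vmap
        out = []
        for item, k in zip(items, key):
            ids = jnp.asarray(item["input_ids"])
            keep = jax.random.bernoulli(k, 0.9, ids.shape)  # toy augmentation: drop ~10% of tokens
            out.append({**item, "input_ids": jnp.where(keep, ids, 0)})
        return out

    # every get_batch call now maps the whole batch at once, with per-item keys derived from the base key
    noisy_ds = train_ds.map_batches(add_noise_batch, key=jax.random.PRNGKey(0))

This mirrors how `mk_supervised_dataset`, `_cache_supervised_set`, and `mk_chat_sft_dataset` now call `cached_dataset.map_batches(...)`, which lets `_prepare_supervised_examples` pad and truncate an entire batch with a single `tokenizer.pad` call instead of padding each example separately.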