Skip to content

Commit

Permalink
jit and batch supervised data loading to speed it up (a lot) (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Nov 21, 2024
1 parent 8509037 commit f44e5c8
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 23 deletions.
76 changes: 73 additions & 3 deletions src/levanter/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Generic, Optional, Sequence, TypeVar
from typing import Callable, Generic, Optional, Sequence, TypeAlias, TypeVar

import jax.random
import numpy as np
Expand All @@ -18,6 +18,11 @@
T = TypeVar("T")
U = TypeVar("U")

# When we decide to standardize on 3.12, we can use fancier things
# P = ParamSpec("P")

MapFunction: TypeAlias = Callable[..., U]


_executor = ThreadPoolExecutor(max_workers=10)

Expand Down Expand Up @@ -111,9 +116,12 @@ def as_sync_dataset(self):
def as_async_dataset(self) -> "AsyncDataset[T_co]":
return self

def map(self, fn: Callable[[T_co], U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[T_co, U]":
def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[T_co, U]":
return MappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)

def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)

def shuffle(self, key: PRNGKey):
import levanter.data.permutation as permutation

Expand Down Expand Up @@ -321,7 +329,7 @@ class MappedAsyncDataset(AsyncDataset[U], Generic[T, U]):
def __init__(
self,
dataset: AsyncDataset[T],
fn: Callable[[T], U] | Callable[[T, Optional[PRNGKey]], U],
fn: MapFunction[U],
*extra_args,
**extra_kwargs,
):
Expand Down Expand Up @@ -365,3 +373,65 @@ def _call_fn(self, index, item):
else:
kwargs = self._extra_kwargs
return self.fn(item, *self._extra_args, **kwargs)


class BatchMappedAsyncDataset(AsyncDataset[U]):
"""
A dataset that applies a function to each batch of items in the dataset.
You can pass extra arguments to the function using `*extra_args` and `**extra_kwargs`.
If a kwarg called `key` is passed, it will be treated as a PRNGKey and folded in with the index of the item
for each call to the function. The key will be split into a key for each item in the batch.
"""

def __init__(
self,
dataset: AsyncDataset[T],
fn: MapFunction[Sequence[U]],
*extra_args,
**extra_kwargs,
):
super().__init__()
self.dataset = dataset
self.fn = fn
self._extra_args = extra_args
self._extra_kwargs = extra_kwargs

async def async_len(self) -> int:
return await self.dataset.async_len()

async def final_length_is_known(self) -> bool:
return await self.dataset.final_length_is_known()

def is_finite(self) -> bool:
return self.dataset.is_finite()

async def current_len(self) -> Optional[int]:
return await self.dataset.current_len()

def _maybe_fold_in_key(self, key, indices: Sequence[int]):
if key is not None:
key = _fold_in_key_vmap(key, np.array(indices))
return key

async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
items = await self.dataset.get_batch(indices)
return self._call_fn(indices, items)

async def getitem_async(self, index: int) -> U:
return self._call_fn([index], [await self.dataset.getitem_async(index)])[0]

async def wait_until_len_at_least(self, length: int) -> int:
return await self.dataset.wait_until_len_at_least(length)

def _call_fn(self, indices: Sequence[int], items):
if "key" in self._extra_kwargs:
key = self._maybe_fold_in_key(self._extra_kwargs["key"], indices)
kwargs = {**self._extra_kwargs, "key": key}
else:
kwargs = self._extra_kwargs
return self.fn(items, *self._extra_args, **kwargs)


@jax.jit
def _fold_in_key_vmap(key, indices):
return jax.vmap(lambda i: jax.random.fold_in(key, i))(indices)
45 changes: 27 additions & 18 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,36 +779,45 @@ def _preprocess_supervised_example(
}


def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> LmExample:
def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> list[LmExample]:
"""
Prepare an example for training. This function converts the (cached) batch encoding into an LmExample.
Prepare examples for training. This function converts the (cached) encodings into an LmExample.
It goes through the following steps:
1. Pad the batch to the maximum length.
2. Mask out the input and prompt if requested.
3. Create an LmExample with the input_ids as the input and the next token as the target.
"""
# annoyingly, pad expects things to be batched so we have to prepend a batch axis
ex = tokenizer.pad(
{k: np.expand_dims(v, 0) for k, v in ex.items()},
return_tensors="np",
lens = np.array([ex["sources_len"] for ex in ex])

ex_pad = tokenizer.pad(
ex,
padding="max_length",
max_length=Pos.size,
)
ex = {k: v[0] for k, v in ex.items()}
# padding doesn't do truncation, so we have to do it ourselves.
# Truncate from the left since we want to predict the last tokens
input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos)
# mask out padding and anything before the start of the target
loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1

input_ids = ex_pad["input_ids"]
truncated = [ids[-Pos.size :] for ids in input_ids]

out = []
for ids, len in zip(truncated, lens):
causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id)

out.append(causal)

return out


@functools.partial(jax.jit, static_argnums=(0, 3))
def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id):
# mask out padding and anything before the start of the target
loss_mask = hax.arange(Pos) >= sources_len - 1
# don't predict the padding
targets = hax.roll(input_ids, -1, Pos)
loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
loss_mask = loss_mask & (targets != pad_token_id)
loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
return lm_ex
return LmExample.causal(input_ids, loss_mask=loss_mask)


def mk_supervised_datasets(
Expand Down Expand Up @@ -884,7 +893,7 @@ def mk_supervised_dataset(
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))
return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos))


def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output_field):
Expand All @@ -899,7 +908,7 @@ def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output
output_exemplar=output_exemplar,
)
cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(cache_dir, await_finished=True)
ds = cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))
ds = cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos))
return ds


Expand Down Expand Up @@ -994,7 +1003,7 @@ def mk_chat_sft_dataset(
tokenizer.pad_token = tokenizer.eos_token

# Reuse the supervised prepare function directly
return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))
return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos))


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions tests/test_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import haliax
from haliax import Axis

from levanter.data.text import _prepare_supervised_example, _preprocess_supervised_example
from levanter.data.text import _prepare_supervised_examples, _preprocess_supervised_example


def test_supervised_eval():
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_supervised_eval():
"sources_len": np.array(45, dtype=np.int32),
}

lm_ex = _prepare_supervised_example(ex, tokenizer, Axis("position", 128))
lm_ex = _prepare_supervised_examples([ex], tokenizer, Axis("position", 128))[0]

assert lm_ex.loss_mask["position", 44]
assert haliax.sum(lm_ex.loss_mask) == 1

0 comments on commit f44e5c8

Please sign in to comment.