fix crash in data loader caused by using stale array (#765)
dlwh authored Oct 14, 2024
1 parent fc26c74 commit 02f34ac
Showing 2 changed files with 155 additions and 138 deletions.
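
Overview of the change, as reflected in the diff below: global device batches are now assembled in DataLoaderIterator.__next__ from freshly retrieved per-host data (via the new _batchify_local_data helper) rather than inside the prefetch generator, where the backing arrays could go stale by the time they were consumed; all host-side loading is pinned to a CPU mesh with local_cpu_mesh(); and prefetching runs through a CPU-pinned _JaxCpuBackgroundIterator, or an AsyncIteratorWrapper when buffering is disabled.

The central JAX API in the batch-assembly path is jax.make_array_from_callback. The minimal sketch below is illustrative only (single host, made-up shapes and names, not code from this commit) and shows how a globally sharded batch is built from host-local data:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Host-local batch of 32 examples with 4 features (hypothetical data).
    local_batch = np.arange(32 * 4, dtype=np.float32).reshape(32, 4)

    # One-axis mesh over the local devices; the batch dimension is sharded across it.
    mesh = Mesh(np.array(jax.devices()), ("batch",))
    sharding = NamedSharding(mesh, PartitionSpec("batch", None))

    def get_data(indices):
        # indices is a tuple of slices into the global array; return the local piece.
        return local_batch[indices]

    global_batch = jax.make_array_from_callback(local_batch.shape, sharding, get_data)

In the new code, the background thread only fetches raw per-host examples, and the callback-based assembly happens at consumption time in __next__.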
235 changes: 124 additions & 111 deletions src/levanter/data/loader.py
@@ -2,7 +2,7 @@
import logging
import time
from collections import defaultdict
from typing import Iterable, Iterator, Optional, Tuple, TypeVar
from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar

import jax
from jax import Array
@@ -20,8 +20,9 @@
from levanter.data.dataset import AsyncDataset
from levanter.data.utils import batched
from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape
from levanter.utils.background_iterable import BackgroundIterable
from levanter.utils.thread_utils import blocking_wait
from levanter.utils.background_iterable import BackgroundIterator
from levanter.utils.jax_utils import local_cpu_mesh
from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait


Ex = TypeVar("Ex")
@@ -62,10 +63,11 @@ def __init__(
self.mesh = mesh
self.Batch = Batch

def _exemplar_shape():
return blocking_wait(self.data_store.getitem_async(0))

self._ex_leaves, self._ex_structure = jax.tree_flatten(_exemplar_shape(), is_leaf=is_named_array)
with local_cpu_mesh():
# It's important that all data loading happens CPU side. We might relax this one day.
self._ex_leaves, self._ex_structure = jax.tree_flatten(
blocking_wait(self.data_store.getitem_async(0)), is_leaf=is_named_array
)

local_device_indices, local_indices = self._compute_local_device_indices()

Expand Down Expand Up @@ -98,6 +100,8 @@ def __iter__(self):
return self.iter_from_step(None)

def iter_from_step(self, start_from_batch: Optional[int] = None):
# sometimes we pass in an array for the start_from_batch, so we need to check for that
start_from_batch = int(start_from_batch) if start_from_batch is not None else None
return DataLoaderIterator(self, start_from_batch=start_from_batch)


@@ -109,115 +113,131 @@ def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = No
if self.mapping is None:
self.mapping = hax.partitioning.current_thread_local_mapping()

# TODO: bring back non-prefetching version
buffered_batches = self.dl.max_buffered_batches
self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches))
self._batches: Iterator[Ex]
if buffered_batches == 0:
self._batches = AsyncIteratorWrapper(self._produce_batches())
else:
self._batches = _JaxCpuBackgroundIterator(self._produce_batches, max_capacity=buffered_batches)

def __next__(self):
time_start = time.time()
out = next(self._batches)
individual_data_batch = next(self._batches)
data_for_this_batch = {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)}
batch = self._batchify_local_data(data_for_this_batch)

time_end = time.time()
if (time_end - time_start) > 0.5:
logger.info(f"Prefetch wasn't fast enough: {time_end - time_start:.3f}")
return out
return batch

async def _produce_batches(self):
batch_number = self._start_from_batch or 0
total_ex_loaded = 0
done = False
while not done:
next_batch_numbers = []
for i in range(self.dl.prefetch_size):
if self.dl.data_store.is_finite():
next_end = (batch_number + 1) * self.dl.batch_size
available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
if available_len < next_end:
done = True
break

next_batch_numbers.append(batch_number)
batch_number += 1
target_next_batch_number = batch_number + self.dl.prefetch_size
max_achievable_batch_number = await self._dataset_get_available_batch_number(target_next_batch_number)
if max_achievable_batch_number < target_next_batch_number:
done = True

next_batch_numbers = list(range(batch_number, min(target_next_batch_number, max_achievable_batch_number)))

if len(next_batch_numbers) == 0:
break

batch_number = next_batch_numbers[-1] + 1

async for batch in self._retrieve_batches(next_batch_numbers):
yield batch

total_ex_loaded += self.dl.batch_size * len(next_batch_numbers)
async def _dataset_get_available_batch_number(self, target_max_batch_number: int) -> int:
if self.dl.data_store.is_finite():
next_end = (target_max_batch_number + 1) * self.dl.batch_size
available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
max_achievable_batch_number = available_len // self.dl.batch_size

async def _retrieve_batches(self, batch_numbers: list[int]):
with hax.axis_mapping(self.mapping), self.dl.mesh:
indices_for_this_batch_of_batches: list[int] = []
for bn in batch_numbers:
indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1)
indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices]
indices_for_this_batch_of_batches.extend(indices_this_batch_this_process)
return max_achievable_batch_number

return target_max_batch_number

async def _retrieve_batches(self, batch_numbers: list[int]):
with local_cpu_mesh():
time_start = time.time()
individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches)
individual_datums_for_each_batch = await self._do_retrieve_batch_of_batches(batch_numbers)
# reshape to be per batch
time_end = time.time()
logger.debug(f"Time to get {len(batch_numbers)} batches: {time_end - time_start:.3f}")
time_start = time.time()
# reshape to be per batch
individual_datums = list(batched(individual_datums, len(self.dl._local_indices)))

# below we're gonna get the indices relative to this batch (i.e. 0 to batch_size)
index_to_datum = [
{index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)}
for individual_data_batch in individual_datums
]

def get_local_batch(bn: int, begin: int, end: int) -> list:
# TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
# which will require support from the datastore (i.e. tensorstore)
device_batch = _stack_tree(self.dl.Batch.name, [index_to_datum[bn][i] for i in range(begin, end)])
batch_leaves = hax.tree_util.tree_leaves(device_batch)
return batch_leaves

def get_local_data_for_leaf(bn, indices: _TensorSliceIndex, leaf_index: int) -> Array:
batch_slice = indices[0]
begin, end, stride = batch_slice.indices(self.dl.batch_size)
if stride != 1:
raise ValueError("Stride must be 1")

leaf_data = (get_local_batch(bn, begin, end))[leaf_index]

if isinstance(leaf_data, hax.NamedArray):
# select out the batch axis
batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes)
new_indices = list(indices)
new_indices[batch_index] = slice(None)
return leaf_data.array[tuple(new_indices)]

for data in individual_datums_for_each_batch:
yield data

def _batchify_local_data(self, data_for_this_batch: dict[int, Array]):
cache: dict[tuple[int, int], list[Array | hax.NamedArray]] = {}

def get_local_batch(begin: int, end: int) -> list:
if (begin, end) in cache:
return cache[(begin, end)]

# TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
# which will require support from the datastore (i.e. tensorstore)
device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
batch_leaves = hax.tree_util.tree_leaves(device_batch)

cache[(begin, end)] = batch_leaves

return batch_leaves

def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Array:
batch_slice = indices[0]
begin, end, stride = batch_slice.indices(self.dl.batch_size)
if stride != 1:
raise ValueError("Stride must be 1")

leaf_data = get_local_batch(begin, end)[leaf_index]

if isinstance(leaf_data, hax.NamedArray):
# select out the batch axis
batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes)
new_indices = list(indices)
new_indices[batch_index] = slice(None)
return leaf_data.array[tuple(new_indices)]
else:
other_indices = indices[1:]
if all(idx == slice(None) for idx in other_indices):
return leaf_data
else:
other_indices = indices[1:]
if all(idx == slice(None) for idx in other_indices):
return leaf_data
else:
# TODO: this doesn't work with named axes
return leaf_data[(..., *other_indices)]

for batch_offset, bn in enumerate(batch_numbers):

def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec):
def get_data(indices):
return get_local_data_for_leaf(batch_offset, indices, leaf_index)

raw_array = jax.make_array_from_callback(
to_raw_shape(item_leaf_shape),
jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)),
get_data,
)
if isinstance(item_leaf_shape, NamedShapeSpec):
return hax.NamedArray(raw_array, item_leaf_shape.shape)
else:
return raw_array

gda_leaves = [
make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf))
for leaf_index, item_leaf in enumerate(self.dl._ex_leaves)
]

gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves)
yield gda_tree
return leaf_data[(..., *other_indices)]

def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec):
def get_data(indices):
return get_local_data_for_leaf(indices, leaf_index)

raw_array = jax.make_array_from_callback(
to_raw_shape(item_leaf_shape),
jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)),
get_data,
)
if isinstance(item_leaf_shape, NamedShapeSpec):
return hax.NamedArray(raw_array, item_leaf_shape.shape)
else:
return raw_array

gda_leaves = [
make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf))
for leaf_index, item_leaf in enumerate(self.dl._ex_leaves)
]
gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves)
return gda_tree

async def _do_retrieve_batch_of_batches(self, batch_numbers):
indices_for_this_batch_of_batches: list[int] = []
for bn in batch_numbers:
indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1)
indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices]
indices_for_this_batch_of_batches.extend(indices_this_batch_this_process)
individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches)
individual_datums_for_each_batch = list(batched(individual_datums, len(self.dl._local_indices)))
return individual_datums_for_each_batch

def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
if isinstance(shape_spec, ShapeSpec): # type: ignore
@@ -227,31 +247,24 @@ def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
return hax.partitioning.pspec_for_axis(shape_spec.shape, self.dl.axis_resources) # type: ignore


def _abstractify(x):
def _abstractify_array(x):
if isinstance(x, jax.numpy.ndarray):
return ShapeSpec(x.shape, x.dtype)
elif isinstance(x, hax.NamedArray):
return NamedShapeSpec(x.axes, x.dtype)

return x

return hax.tree_util.tree_map(_abstractify_array, x)


def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedShapeSpec:
if is_named_array(leaf):
return NamedShapeSpec((Batch,) + leaf.axes, leaf.dtype)
else:
return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype)


def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
if isinstance(shape_spec, ShapeSpec): # type: ignore
batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources)
return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1)))
else:
return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore
class _JaxCpuBackgroundIterator(BackgroundIterator[Ex]):
"""
We want the thread to only use the CPU device.
"""

def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]], max_capacity: Optional[int]):
super().__init__(producer_fn, max_capacity)

def _fill_queue_with_batches(self):
with local_cpu_mesh():
super()._fill_queue_with_batches()


@functools.partial(jax.jit, static_argnums=(0,))
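For context on the _JaxCpuBackgroundIterator added above: it overrides _fill_queue_with_batches so that the producer thread does its work under local_cpu_mesh(), keeping prefetched data on the CPU. The sketch below is a hypothetical stand-in, not levanter's BackgroundIterator (cpu_background_iter is an invented name); it shows the same idea of a producer thread filling a bounded queue, with its JAX work pinned to the CPU device:

    import queue
    import threading

    import jax
    import jax.numpy as jnp

    def cpu_background_iter(producer_fn, max_capacity=4):
        q: queue.Queue = queue.Queue(maxsize=max_capacity)
        sentinel = object()

        def worker():
            cpu = jax.devices("cpu")[0]
            with jax.default_device(cpu):  # arrays produced here land on the CPU
                for item in producer_fn():
                    q.put(item)
            q.put(sentinel)

        threading.Thread(target=worker, daemon=True).start()
        while True:
            item = q.get()
            if item is sentinel:
                return
            yield item

    # Usage: iterate batches produced on a background thread; each stays on CPU.
    for batch in cpu_background_iter(lambda: (jnp.ones((8,)) * i for i in range(3))):
        print(batch.devices())
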
58 changes: 31 additions & 27 deletions tests/test_doremi.py
@@ -15,7 +15,7 @@
from levanter.data import AsyncDataset
from levanter.data.mixture import MixtureDataset
from levanter.trainer import Trainer, TrainerConfig
from levanter.utils.jax_utils import key_iterator
from levanter.utils.jax_utils import key_iterator, local_cpu_mesh
from levanter.utils.py_utils import non_caching_cycle


@@ -27,6 +27,15 @@ class Example(equinox.Module):
Block = hax.Axis("Block", 1024)


def platform_of_array(x):
if isinstance(x, jax.Array):
return set(d.platform for d in x.devices())
elif isinstance(x, hax.NamedArray):
return platform_of_array(x.array)
else:
return "cpu"


class LogitDataset(AsyncDataset[Example]):
def __init__(self, W, noise, x_mask, x_bias, *, key):
self.W = W
@@ -52,17 +61,12 @@ def _gen_block_data(block_id):

self._gen_block_data = _gen_block_data

def __iter__(self):
key_iter = key_iterator(self.key)
Dim = self.W.axes[0]
while True:
kk = next(key_iter)
this_key_iter = key_iterator(kk)
x_block = hax.random.normal(next(this_key_iter), (Block, Dim)) * self.x_mask + self.x_bias
noise = hax.random.normal(next(this_key_iter), (Block,)) * self.noise
y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=Dim) + noise) > 0.5).astype(float)
for i in range(Block.size):
yield self._make_example(x_block, y_block, i)
def _make_block(self, Dim, kk):
this_key_iter = key_iterator(kk)
x_block = hax.random.normal(next(this_key_iter), (Block, Dim)) * self.x_mask + self.x_bias
noise = hax.random.normal(next(this_key_iter), (Block,)) * self.noise
y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=Dim) + noise) > 0.5).astype(float)
return x_block, y_block

async def async_len(self) -> int:
raise ValueError("Infinitely long dataset")
@@ -106,21 +110,21 @@ def test_estimate_mixture_weights():
Dim = hax.Axis("Dim", 5)
Batch = hax.Axis("Batch", 32)

keys = key_iterator(0)

# W = hax.random.normal(next(keys), (Dim,))
W1 = hax.named([0.0, 0.5, 0.5, 0.0, 0.0], (Dim,))
x1_mask = hax.named([0.0, 1.0, 1.0, 0.0, 0.0], (Dim,))
W2 = hax.named([0.0, 0.0, 0.0, 0.0, 0.0], (Dim,))
x2_mask = hax.named([0.0, 0.0, 0.0, 1.0, 1.0], (Dim,))
W3 = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,))
x3_mask = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,))
x3_bias = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,))

# y = sigmoid(Wx + b + N(0, noise^2)) > 0.5
ds1 = LogitDataset(W1, 0.1, x1_mask, 0.0, key=next(keys))
ds2 = LogitDataset(W2, 2.0, x2_mask, 0.0, key=next(keys))
ds3 = LogitDataset(W3, 0.05, x3_mask, x3_bias, key=next(keys))
# data loading needs to take place on CPU
with local_cpu_mesh():
keys = key_iterator(0)
W1 = hax.named([0.0, 0.5, 0.5, 0.0, 0.0], (Dim,))
x1_mask = hax.named([0.0, 1.0, 1.0, 0.0, 0.0], (Dim,))
W2 = hax.named([0.0, 0.0, 0.0, 0.0, 0.0], (Dim,))
x2_mask = hax.named([0.0, 0.0, 0.0, 1.0, 1.0], (Dim,))
W3 = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,))
x3_mask = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,))
x3_bias = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,))

# y = sigmoid(Wx + b + N(0, noise^2)) > 0.5
ds1 = LogitDataset(W1, 0.1, x1_mask, 0.0, key=next(keys))
ds2 = LogitDataset(W2, 2.0, x2_mask, 0.0, key=next(keys))
ds3 = LogitDataset(W3, 0.05, x3_mask, x3_bias, key=next(keys))

# TODO: remove key as a requirement for models
def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key=None):

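Both files rely on local_cpu_mesh() to keep data construction CPU-side (see the "data loading needs to take place on CPU" comment in the test). As a rough sketch of what such a context manager can look like, under the assumption that it pins the default device and the active mesh to the host CPU (cpu_only_mesh is an invented name; levanter's local_cpu_mesh may differ):

    import contextlib

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh

    @contextlib.contextmanager
    def cpu_only_mesh():
        # Pin both the default device and the active mesh to the host CPU.
        cpu = jax.devices("cpu")[0]
        with jax.default_device(cpu), Mesh(np.array([cpu]), ("data",)):
            yield

    with cpu_only_mesh():
        x = jnp.zeros((4,))
        assert all(d.platform == "cpu" for d in x.devices())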