diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c0a86d830..9843877e2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -196,7 +196,7 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: len = await self.wait_until_len_at_least(max(indices) + 1) if len is not None and len < max(indices) + 1: raise ValueError("Requested indices beyond the end of the dataset") - offsets = np.array(indices) * self.seq_len + offsets = np.array(indices, dtype=np.int64) * self.seq_len with ts.Batch(): out = [] for offset in offsets: