Skip to content

Commit

Permalink
Add support for batching inputs to avoid overwhelming the backend ser…
Browse files Browse the repository at this point in the history
…vice
  • Loading branch information
tomwhite committed Aug 14, 2023
1 parent 6ad5601 commit d348bcb
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 5 deletions.
21 changes: 20 additions & 1 deletion cubed/runtime/executors/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, AsyncIterator, Callable, Dict, Iterable, List, Optional, Tuple

from cubed.runtime.backup import should_launch_backup
from cubed.runtime.utils import batched


async def async_map_unordered(
Expand All @@ -14,6 +15,7 @@ async def async_map_unordered(
create_backup_futures_func: Optional[
Callable[..., List[Tuple[Any, Future]]]
] = None,
batch_size: Optional[int] = None,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
Expand All @@ -25,8 +27,14 @@ async def async_map_unordered(
if create_backup_futures_func is None:
create_backup_futures_func = create_futures_func

if batch_size is None:
inputs = input
else:
input_batches = batched(input, batch_size)
inputs = next(input_batches)

task_create_tstamp = time.time()
tasks = {task: i for i, task in create_futures_func(input, **kwargs)}
tasks = {task: i for i, task in create_futures_func(inputs, **kwargs)}
pending = set(tasks.keys())
t = time.monotonic()
start_times = {f: t for f in pending}
Expand Down Expand Up @@ -81,3 +89,14 @@ async def async_map_unordered(
pending.add(new_task)
backups[task] = new_task
backups[new_task] = task

if batch_size is not None and len(pending) < batch_size:
inputs = next(input_batches, None) # type: ignore
if inputs is not None:
new_tasks = {
task: i for i, task in create_futures_func(inputs, **kwargs)
}
tasks.update(new_tasks)
pending.update(new_tasks.keys())
t = time.monotonic()
start_times = {f: t for f in new_tasks.keys()}
2 changes: 2 additions & 0 deletions cubed/runtime/executors/dask_distributed_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def map_unordered(
map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
retries: int = 2,
use_backups: bool = False,
batch_size: Optional[int] = None,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
Expand Down Expand Up @@ -68,6 +69,7 @@ def create_backup_futures_func(input, **kwargs):
map_iterdata,
use_backups=use_backups,
create_backup_futures_func=create_backup_futures_func,
batch_size=batch_size,
return_stats=return_stats,
name=name,
**kwargs,
Expand Down
4 changes: 3 additions & 1 deletion cubed/runtime/executors/modal_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def map_unordered(
input: Iterable[Any],
use_backups: bool = False,
backup_function: Optional[Function] = None,
batch_size: Optional[int] = None,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
Expand All @@ -44,7 +45,7 @@ async def map_unordered(
:return: Function values (and optionally stats) as they are completed, not necessarily in the input order.
"""

if not use_backups:
if not use_backups and batch_size is None:
task_create_tstamp = time.time()
async for result in app_function.map(input, order_outputs=False, kwargs=kwargs):
if return_stats:
Expand Down Expand Up @@ -76,6 +77,7 @@ def create_backup_futures_func(input, **kwargs):
input,
use_backups=use_backups,
create_backup_futures_func=create_backup_futures_func,
batch_size=batch_size,
return_stats=return_stats,
name=name,
**kwargs,
Expand Down
2 changes: 2 additions & 0 deletions cubed/runtime/executors/python_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ async def map_unordered(
input: Iterable[Any],
retries: int = 2,
use_backups: bool = False,
batch_size: Optional[int] = None,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
Expand All @@ -54,6 +55,7 @@ def create_futures_func(input, **kwargs):
create_futures_func,
input,
use_backups=use_backups,
batch_size=batch_size,
return_stats=return_stats,
name=name,
**kwargs,
Expand Down
11 changes: 11 additions & 0 deletions cubed/runtime/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from functools import partial
from itertools import islice

from cubed.utils import peak_measured_mem

Expand Down Expand Up @@ -52,3 +53,13 @@ def handle_callbacks(callbacks, stats):
else:
event = TaskEndEvent(**stats)
[callback.on_task_end(event) for callback in callbacks]


# this will be in Python 3.12 https://docs.python.org/3.12/library/itertools.html#itertools.batched
def batched(iterable, n):
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch
19 changes: 18 additions & 1 deletion cubed/tests/runtime/test_dask_distributed_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import itertools
from functools import partial

import pytest
Expand All @@ -12,7 +13,7 @@
from cubed.runtime.executors.dask_distributed_async import map_unordered


async def run_test(function, input, retries, use_backups=False):
async def run_test(function, input, retries, use_backups=False, batch_size=None):
outputs = set()
async with Client(asynchronous=True) as client:
async for output in map_unordered(
Expand All @@ -21,6 +22,7 @@ async def run_test(function, input, retries, use_backups=False):
input,
retries=retries,
use_backups=use_backups,
batch_size=batch_size,
):
outputs.add(output)
return outputs
Expand Down Expand Up @@ -98,3 +100,18 @@ def test_stragglers(tmp_path, timing_map, n_tasks, retries):
assert outputs == set(range(n_tasks))

check_invocation_counts(tmp_path, timing_map, n_tasks, retries)


def test_batch(tmp_path):
# input is unbounded, so if entire input were consumed and not read
# in batches then it would never return, since it would never
# run the first (failing) input
with pytest.raises(RuntimeError):
asyncio.run(
run_test(
function=partial(deterministic_failure, tmp_path, {0: [-1]}),
input=itertools.count(),
retries=0,
batch_size=10,
)
)
27 changes: 26 additions & 1 deletion cubed/tests/runtime/test_modal_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools

import pytest

modal = pytest.importorskip("modal")
Expand Down Expand Up @@ -48,13 +50,14 @@ def deterministic_failure_modal_long_timeout(i, path=None, timing_map=None):
return deterministic_failure(path, timing_map, i)


async def run_test(app_function, input, use_backups=False, **kwargs):
async def run_test(app_function, input, use_backups=False, batch_size=None, **kwargs):
outputs = set()
async with stub.run():
async for output in map_unordered(
app_function,
input,
use_backups=use_backups,
batch_size=batch_size,
**kwargs,
):
outputs.add(output)
Expand Down Expand Up @@ -181,3 +184,25 @@ def test_stragglers(timing_map, n_tasks, retries, expected_invocation_counts_ove
finally:
fs = fsspec.open(tmp_path).fs
fs.rm(tmp_path, recursive=True)


@pytest.mark.cloud
def test_batch(tmp_path):
# input is unbounded, so if entire input were consumed and not read
# in batches then it would never return, since it would never
# run the first (failing) input
try:
with pytest.raises(RuntimeError):
asyncio.run(
run_test(
app_function=deterministic_failure_modal_no_retries,
input=itertools.count(),
path=tmp_path,
timing_map={0: [-1]},
batch_size=10,
)
)

finally:
fs = fsspec.open(tmp_path).fs
fs.rm(tmp_path, recursive=True)
19 changes: 18 additions & 1 deletion cubed/tests/runtime/test_python_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import itertools
from concurrent.futures import ThreadPoolExecutor
from functools import partial

Expand All @@ -8,7 +9,7 @@
from cubed.tests.runtime.utils import check_invocation_counts, deterministic_failure


async def run_test(function, input, retries=2, use_backups=False):
async def run_test(function, input, retries=2, use_backups=False, batch_size=None):
outputs = set()
concurrent_executor = ThreadPoolExecutor()
try:
Expand All @@ -18,6 +19,7 @@ async def run_test(function, input, retries=2, use_backups=False):
input,
retries=retries,
use_backups=use_backups,
batch_size=batch_size,
):
outputs.add(output)
finally:
Expand Down Expand Up @@ -95,3 +97,18 @@ def test_stragglers(tmp_path, timing_map, n_tasks, retries):
assert outputs == set(range(n_tasks))

check_invocation_counts(tmp_path, timing_map, n_tasks, retries)


def test_batch(tmp_path):
# input is unbounded, so if entire input were consumed and not read
# in batches then it would never return, since it would never
# run the first (failing) input
with pytest.raises(RuntimeError):
asyncio.run(
run_test(
function=partial(deterministic_failure, tmp_path, {0: [-1]}),
input=itertools.count(),
retries=0,
batch_size=10,
)
)

0 comments on commit d348bcb

Please sign in to comment.