diff --git a/distributed/client.py b/distributed/client.py
index be89e2cf8d..c93f5ff8e5 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -70,7 +70,7 @@
 from tornado import gen
 from tornado.ioloop import IOLoop
 
-from dask._task_spec import DataNode, GraphNode, Task, TaskRef
+from dask._task_spec import DataNode, GraphNode, Task, TaskRef, parse_input
 
 import distributed.utils
 from distributed import cluster_dump, preloading
@@ -622,6 +622,9 @@ def __await__(self):
     def __hash__(self):
         return hash(self._id)
 
+    def __eq__(self, other):
+        return self is other
+
 
 class FutureState:
     """A Future's internal state.
@@ -850,12 +853,10 @@ def __init__(
         **kwargs,
     ):
         self.func: Callable = func
-        self.iterables: Iterable[Any] = (
-            list(zip(*zip(*iterables))) if _is_nested(iterables) else [iterables]
-        )
+        self.iterables = [tuple(map(parse_input, iterable)) for iterable in iterables]
         self.key: str | Iterable[str] | None = key
         self.pure: bool = pure
-        self.kwargs = kwargs
+        self.kwargs = {k: parse_input(v) for k, v in kwargs.items()}
         super().__init__(annotations=annotations)
 
     def __repr__(self) -> str:
@@ -2163,13 +2164,12 @@ def submit(
         if isinstance(workers, (str, Number)):
             workers = [workers]
 
-
         dsk = {
             key: Task(
                 key,
                 func,
-                *args,
-                **kwargs,
+                *(parse_input(a) for a in args),
+                **{k: parse_input(v) for k, v in kwargs.items()},
             )
         }
         futures = self._graph_to_futures(
diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py
index c782d56e38..010d7d880d 100644
--- a/distributed/shuffle/_core.py
+++ b/distributed/shuffle/_core.py
@@ -25,7 +25,7 @@
 from tornado.ioloop import IOLoop
 
 import dask.config
-from dask._task_spec import Task, _inline_recursively
+from dask._task_spec import Task
 from dask.core import flatten
 from dask.typing import Key
 from dask.utils import parse_bytes, parse_timedelta
@@ -569,7 +569,7 @@ def _mean_shard_size(shards: Iterable) -> int:
     return size // count if count else 0
 
 
-def p2p_barrier(id: ShuffleId, run_ids: list[int]) -> int:
+def p2p_barrier(id: ShuffleId, *run_ids: int) -> int:
     try:
         return get_worker_plugin().barrier(id, run_ids)
     except Reschedule as e:
@@ -599,18 +599,5 @@ def __init__(
         self.spec = spec
         super().__init__(key, func, *args, **kwargs)
 
-    def copy(self) -> P2PBarrierTask:
-        return P2PBarrierTask(
-            self.key, self.func, *self.args, spec=self.spec, **self.kwargs
-        )
-
     def __repr__(self) -> str:
         return f"P2PBarrierTask({self.key!r})"
-
-    def inline(self, dsk: dict[Key, Any]) -> P2PBarrierTask:
-        new_args = _inline_recursively(self.args, dsk)
-        new_kwargs = _inline_recursively(self.kwargs, dsk)
-        assert self.func is not None
-        return P2PBarrierTask(
-            self.key, self.func, *new_args, spec=self.spec, **new_kwargs
-        )
diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py
index c7e62d5558..bf7a532e80 100644
--- a/distributed/shuffle/_merge.py
+++ b/distributed/shuffle/_merge.py
@@ -418,7 +418,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             _barrier_key_left,
             p2p_barrier,
             token_left,
-            transfer_keys_left,
+            *transfer_keys_left,
             spec=DataFrameShuffleSpec(
                 id=shuffle_id_left,
                 npartitions=self.npartitions,
@@ -435,7 +435,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             _barrier_key_right,
             p2p_barrier,
             token_right,
-            transfer_keys_right,
+            *transfer_keys_right,
             spec=DataFrameShuffleSpec(
                 id=shuffle_id_right,
                 npartitions=self.npartitions,
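Note on the barrier-signature change above: `Task` discovers its dependencies from the `TaskRef`s it can see directly among its arguments, so call sites either splat the transfer keys as individual positional arguments (hence `p2p_barrier(id: ShuffleId, *run_ids: int)`) or wrap nested containers with `parse_input` so refs inside plain containers become visible. A rough sketch of the distinction, not verified against a specific dask version; `add` is a stand-in function:

```python
# Sketch only: assumes a dask version where Task does not recursively
# parse raw containers for TaskRefs; exact internals may differ.
from dask._task_spec import Task, TaskRef, parse_input


def add(*xs):
    return sum(xs)


# TaskRefs passed as top-level positional arguments are visible to Task
# and should be recorded as dependencies:
t1 = Task("out", add, TaskRef("a"), TaskRef("b"))
print(t1.dependencies)  # expected: {"a", "b"}

# A TaskRef buried inside a plain list is opaque to Task unless the
# container is converted first with parse_input:
t2 = Task("out2", add, parse_input([TaskRef("a"), TaskRef("b")]))
print(t2.dependencies)  # expected: {"a", "b"}
```

This also explains the `parse_input(rec_cat_arg.tolist())` call in `partial_concatenate` below: that list contains `TaskRef`s, so it must be parsed rather than passed through raw.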
diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py
index 354828415b..8b597f79e4 100644
--- a/distributed/shuffle/_rechunk.py
+++ b/distributed/shuffle/_rechunk.py
@@ -121,7 +121,7 @@
 
 import dask
 import dask.config
-from dask._task_spec import Task, TaskRef
+from dask._task_spec import Task, TaskRef, parse_input
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import Layer
 from dask.tokenize import tokenize
@@ -756,7 +756,9 @@ def partial_concatenate(
         rec_cat_arg[old_partial_index] = TaskRef((input_name,) + old_global_index)
 
     concat_task = Task(
-        (rechunk_name(token),) + global_new_index, concatenate3, rec_cat_arg.tolist()
+        (rechunk_name(token),) + global_new_index,
+        concatenate3,
+        parse_input(rec_cat_arg.tolist()),
     )
     dsk[concat_task.key] = concat_task
     return dsk
@@ -822,7 +824,7 @@ def partial_rechunk(
         _barrier_key,
         p2p_barrier,
         partial_token,
-        transfer_keys,
+        *transfer_keys,
         spec=ArrayRechunkSpec(
             id=ShuffleId(partial_token), new=partial_new, old=partial_old, disk=disk
         ),
diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index 912259e4a0..6f78a65042 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -275,7 +275,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             _barrier_key,
             p2p_barrier,
             token,
-            transfer_keys,
+            *transfer_keys,
             spec=DataFrameShuffleSpec(
                 id=shuffle_id,
                 npartitions=self.npartitions,
diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py
index 27b5b18c94..b2865a3695 100644
--- a/distributed/tests/test_as_completed.py
+++ b/distributed/tests/test_as_completed.py
@@ -131,20 +131,21 @@ def test_as_completed_is_empty(client):
     assert ac.is_empty()
 
 
-def test_as_completed_cancel(client):
-    x = client.submit(inc, 1)
-    y = client.submit(inc, 1)
+@gen_cluster(client=True)
+async def test_as_completed_cancel(c, s, a, b):
+    x = c.submit(inc, 1)
+    y = c.submit(inc, 1)
 
     ac = as_completed([x, y])
-    x.cancel()
+    await x.cancel()
 
-    assert next(ac) is x or y
-    assert next(ac) is y or x
+    async for fut in ac:
+        assert fut is y or fut is x
 
     with pytest.raises(queue.Empty):
         ac.queue.get(timeout=0.1)
 
-    res = list(as_completed([x, y, x]))
+    res = [fut async for fut in as_completed([x, y, x])]
     assert len(res) == 3
     assert set(res) == {x, y}
     assert res.count(x) == 2
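The rewritten test also fixes a weakness in the old assertions: `assert next(ac) is x or y` parses as `(next(ac) is x) or y`, which always passes because a `Future` is truthy. The async rewrite relies on `as_completed` supporting asynchronous iteration when used from within a running event loop. A minimal usage sketch in the same style as the test suite, assuming the `gen_cluster` and `inc` helpers from `distributed.utils_test` already used above:

```python
from distributed import as_completed
from distributed.utils_test import gen_cluster, inc


@gen_cluster(client=True)
async def test_async_iteration(c, s, a, b):
    futures = [c.submit(inc, i) for i in range(3)]
    results = []
    # as_completed yields each future as soon as it finishes,
    # so awaiting it here does not block on the other futures.
    async for fut in as_completed(futures):
        results.append(await fut)
    assert sorted(results) == [1, 2, 3]
```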