Remove recursion in task spec (#8920)
fjetter authored Nov 19, 2024
1 parent cc2584d commit 750cb91
Showing 6 changed files with 26 additions and 36 deletions.
16 changes: 8 additions & 8 deletions distributed/client.py
@@ -70,7 +70,7 @@
 from tornado import gen
 from tornado.ioloop import IOLoop

-from dask._task_spec import DataNode, GraphNode, Task, TaskRef
+from dask._task_spec import DataNode, GraphNode, Task, TaskRef, parse_input

 import distributed.utils
 from distributed import cluster_dump, preloading
@@ -622,6 +622,9 @@ def __await__(self):
     def __hash__(self):
         return hash(self._id)

+    def __eq__(self, other):
+        return self is other
+

 class FutureState:
     """A Future's internal state.
@@ -850,12 +853,10 @@ def __init__(
         **kwargs,
     ):
         self.func: Callable = func
-        self.iterables: Iterable[Any] = (
-            list(zip(*zip(*iterables))) if _is_nested(iterables) else [iterables]
-        )
+        self.iterables = [tuple(map(parse_input, iterable)) for iterable in iterables]
         self.key: str | Iterable[str] | None = key
         self.pure: bool = pure
-        self.kwargs = kwargs
+        self.kwargs = {k: parse_input(v) for k, v in kwargs.items()}
         super().__init__(annotations=annotations)

     def __repr__(self) -> str:
@@ -2163,13 +2164,12 @@ def submit(

         if isinstance(workers, (str, Number)):
             workers = [workers]
-
         dsk = {
             key: Task(
                 key,
                 func,
-                *args,
-                **kwargs,
+                *(parse_input(a) for a in args),
+                **{k: parse_input(v) for k, v in kwargs.items()},
             )
         }
         futures = self._graph_to_futures(
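
Context for the hunks above: `parse_input` normalizes user-supplied arguments into task-spec nodes once, at submission time, instead of having the task machinery recurse into raw Python containers later; that is the recursion this commit removes. A minimal sketch of the idea (the concrete node types returned by `parse_input` are a dask implementation detail):

    from dask._task_spec import Task, TaskRef, parse_input

    def pair(x, y):
        return (x, y)

    # A nested container holding a reference to another task's output:
    # parse_input converts the whole structure into task-spec nodes up
    # front, so the dependency on "a" is visible without recursing into
    # raw lists/dicts at execution time.
    arg = parse_input([1, {"other": TaskRef("a")}])
    t = Task("b", pair, arg, 1)
    print(t.dependencies)  # expected: frozenset({"a"})
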
17 changes: 2 additions & 15 deletions distributed/shuffle/_core.py
@@ -25,7 +25,7 @@
 from tornado.ioloop import IOLoop

 import dask.config
-from dask._task_spec import Task, _inline_recursively
+from dask._task_spec import Task
 from dask.core import flatten
 from dask.typing import Key
 from dask.utils import parse_bytes, parse_timedelta
@@ -569,7 +569,7 @@ def _mean_shard_size(shards: Iterable) -> int:
     return size // count if count else 0


-def p2p_barrier(id: ShuffleId, run_ids: list[int]) -> int:
+def p2p_barrier(id: ShuffleId, *run_ids: int) -> int:
     try:
         return get_worker_plugin().barrier(id, run_ids)
     except Reschedule as e:
@@ -599,18 +599,5 @@ def __init__(
         self.spec = spec
         super().__init__(key, func, *args, **kwargs)

-    def copy(self) -> P2PBarrierTask:
-        return P2PBarrierTask(
-            self.key, self.func, *self.args, spec=self.spec, **self.kwargs
-        )
-
     def __repr__(self) -> str:
         return f"P2PBarrierTask({self.key!r})"
-
-    def inline(self, dsk: dict[Key, Any]) -> P2PBarrierTask:
-        new_args = _inline_recursively(self.args, dsk)
-        new_kwargs = _inline_recursively(self.kwargs, dsk)
-        assert self.func is not None
-        return P2PBarrierTask(
-            self.key, self.func, *new_args, spec=self.spec, **new_kwargs
-        )
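
Taking `*run_ids` as varargs means the barrier task now carries its transfer dependencies as flat positional arguments rather than inside a single list argument, which is what made the recursive `copy`/`inline` overrides on `P2PBarrierTask` unnecessary. A sketch under that assumption, with hypothetical keys:

    from dask._task_spec import Task, TaskRef

    def barrier(token, *run_ids):
        return token

    # Flat positional TaskRefs: the task's dependencies are discovered
    # directly from its arguments, with no container to recurse into.
    t = Task("barrier-0", barrier, "tok",
             TaskRef("transfer-0"), TaskRef("transfer-1"))
    print(t.dependencies)  # expected: frozenset({"transfer-0", "transfer-1"})
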
4 changes: 2 additions & 2 deletions distributed/shuffle/_merge.py
@@ -418,7 +418,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             _barrier_key_left,
             p2p_barrier,
             token_left,
-            transfer_keys_left,
+            *transfer_keys_left,
             spec=DataFrameShuffleSpec(
                 id=shuffle_id_left,
                 npartitions=self.npartitions,
@@ -435,7 +435,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             _barrier_key_right,
             p2p_barrier,
             token_right,
-            transfer_keys_right,
+            *transfer_keys_right,
             spec=DataFrameShuffleSpec(
                 id=shuffle_id_right,
                 npartitions=self.npartitions,
8 changes: 5 additions & 3 deletions distributed/shuffle/_rechunk.py
@@ -121,7 +121,7 @@

 import dask
 import dask.config
-from dask._task_spec import Task, TaskRef
+from dask._task_spec import Task, TaskRef, parse_input
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import Layer
 from dask.tokenize import tokenize
@@ -756,7 +756,9 @@ def partial_concatenate(
         rec_cat_arg[old_partial_index] = TaskRef((input_name,) + old_global_index)

     concat_task = Task(
-        (rechunk_name(token),) + global_new_index, concatenate3, rec_cat_arg.tolist()
+        (rechunk_name(token),) + global_new_index,
+        concatenate3,
+        parse_input(rec_cat_arg.tolist()),
     )
     dsk[concat_task.key] = concat_task
     return dsk
@@ -822,7 +824,7 @@ def partial_rechunk(
         _barrier_key,
         p2p_barrier,
         partial_token,
-        transfer_keys,
+        *transfer_keys,
         spec=ArrayRechunkSpec(
             id=ShuffleId(partial_token), new=partial_new, old=partial_old, disk=disk
         ),
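
In `partial_concatenate`, `rec_cat_arg.tolist()` is a nested list of `TaskRef`s, one per input chunk. Wrapping it in `parse_input` turns that whole structure into task-spec nodes at graph-construction time, so `concatenate3` still receives the materialized nested arrays at runtime while nothing has to recurse into a raw list later. A hedged sketch with a 2x1 chunk layout and made-up keys:

    import numpy as np
    from dask._task_spec import Task, TaskRef, parse_input
    from dask.array.core import concatenate3

    refs = np.empty((2, 1), dtype=object)
    refs[0, 0] = TaskRef(("x", 0, 0))
    refs[1, 0] = TaskRef(("x", 1, 0))

    # The nested list of TaskRefs becomes a task-spec container node;
    # both input chunks appear as dependencies of the concat task.
    t = Task(("rechunk", 0, 0), concatenate3, parse_input(refs.tolist()))
    print(t.dependencies)  # expected: {("x", 0, 0), ("x", 1, 0)}
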
2 changes: 1 addition & 1 deletion distributed/shuffle/_shuffle.py
@@ -275,7 +275,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             _barrier_key,
             p2p_barrier,
             token,
-            transfer_keys,
+            *transfer_keys,
             spec=DataFrameShuffleSpec(
                 id=shuffle_id,
                 npartitions=self.npartitions,
15 changes: 8 additions & 7 deletions distributed/tests/test_as_completed.py
@@ -131,20 +131,21 @@ def test_as_completed_is_empty(client):
     assert ac.is_empty()


-def test_as_completed_cancel(client):
-    x = client.submit(inc, 1)
-    y = client.submit(inc, 1)
+@gen_cluster(client=True)
+async def test_as_completed_cancel(c, s, a, b):
+    x = c.submit(inc, 1)
+    y = c.submit(inc, 1)

     ac = as_completed([x, y])
-    x.cancel()
+    await x.cancel()

-    assert next(ac) is x or y
-    assert next(ac) is y or x
+    async for fut in ac:
+        assert fut is y or fut is x

     with pytest.raises(queue.Empty):
         ac.queue.get(timeout=0.1)

-    res = list(as_completed([x, y, x]))
+    res = [fut async for fut in as_completed([x, y, x])]
     assert len(res) == 3
     assert set(res) == {x, y}
     assert res.count(x) == 2
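
The reworked test leans on the `Future.__eq__` added in client.py above: futures hash by their id but compare by identity, so a duplicated future counts twice in a list while a set deduplicates only identical objects. A stand-in class (not the real `Future`) showing just those semantics:

    class Obj:
        # Mirrors the Future semantics from this commit:
        # hash by a stable id, equality by identity.
        def __init__(self, key):
            self._id = key

        def __hash__(self):
            return hash(self._id)

        def __eq__(self, other):
            return self is other

    a, b = Obj("k"), Obj("k")
    assert a != b                    # distinct objects are never equal
    assert [a, b, a].count(a) == 2   # duplicates count by identity
    assert set([a, b, a]) == {a, b}  # equal hashes, but both survive
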
