Skip to content

Commit

Permalink
Properly convert finalize dependencies to references (#8949)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Nov 26, 2024
1 parent b57cb2c commit 0660dab
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from dask.layers import Layer
from dask.optimization import SubgraphCallable
from dask.tokenize import tokenize
from dask.typing import Key, NoDefault, no_default
from dask.typing import Key, NestedKeys, NoDefault, no_default
from dask.utils import (
ensure_dict,
format_bytes,
Expand All @@ -70,7 +70,7 @@
from tornado import gen
from tornado.ioloop import IOLoop

from dask._task_spec import DataNode, GraphNode, Task, TaskRef, parse_input
from dask._task_spec import DataNode, GraphNode, List, Task, TaskRef, parse_input

import distributed.utils
from distributed import cluster_dump, preloading
Expand Down Expand Up @@ -3675,7 +3675,8 @@ def compute(
if func is single_key and len(keys) == 1 and not extra_args:
names[i] = keys[0]
else:
dsk2[name] = (func, keys) + extra_args
t = Task(name, func, _convert_dask_keys(keys), *extra_args)
dsk2[t.key] = t

if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
Expand Down Expand Up @@ -5617,6 +5618,17 @@ def unforward_logging(self, logger_name=None):
return self.unregister_worker_plugin(plugin_name)


def _convert_dask_keys(keys: NestedKeys) -> List:
assert isinstance(keys, list)
new_keys: list[List | TaskRef] = []
for key in keys:
if isinstance(key, list):
new_keys.append(_convert_dask_keys(key))
else:
new_keys.append(TaskRef(key))
return List(*new_keys)


class _WorkerSetupPlugin(WorkerPlugin):
"""This is used to support older setup functions as callbacks"""

Expand Down

0 comments on commit 0660dab

Please sign in to comment.