From 0660dab29c9606617e861ba49154cdbc047833f0 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 26 Nov 2024 11:56:41 +0100 Subject: [PATCH] Properly convert finalize dependencies to references (#8949) --- distributed/client.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index c93f5ff8e5..bba5be19bb 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -47,7 +47,7 @@ from dask.layers import Layer from dask.optimization import SubgraphCallable from dask.tokenize import tokenize -from dask.typing import Key, NoDefault, no_default +from dask.typing import Key, NestedKeys, NoDefault, no_default from dask.utils import ( ensure_dict, format_bytes, @@ -70,7 +70,7 @@ from tornado import gen from tornado.ioloop import IOLoop -from dask._task_spec import DataNode, GraphNode, Task, TaskRef, parse_input +from dask._task_spec import DataNode, GraphNode, List, Task, TaskRef, parse_input import distributed.utils from distributed import cluster_dump, preloading @@ -3675,7 +3675,8 @@ def compute( if func is single_key and len(keys) == 1 and not extra_args: names[i] = keys[0] else: - dsk2[name] = (func, keys) + extra_args + t = Task(name, func, _convert_dask_keys(keys), *extra_args) + dsk2[t.key] = t if not isinstance(dsk, HighLevelGraph): dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=()) @@ -5617,6 +5618,17 @@ def unforward_logging(self, logger_name=None): return self.unregister_worker_plugin(plugin_name) +def _convert_dask_keys(keys: NestedKeys) -> List: + assert isinstance(keys, list) + new_keys: list[List | TaskRef] = [] + for key in keys: + if isinstance(key, list): + new_keys.append(_convert_dask_keys(key)) + else: + new_keys.append(TaskRef(key)) + return List(*new_keys) + + class _WorkerSetupPlugin(WorkerPlugin): """This is used to support older setup functions as callbacks"""