From 28804b72acc210935c061a0d68d46d4a6ae50a94 Mon Sep 17 00:00:00 2001
From: fjetter
Date: Fri, 21 Jun 2024 15:26:58 +0200
Subject: [PATCH] set restrictions during barrier on scheduler

---
 distributed/scheduler.py                 |  4 +-
 distributed/shuffle/_scheduler_plugin.py | 55 ++++++++++++++++++++++--
 2 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 4cba667d9e3..f6a36e15e00 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -7867,7 +7867,9 @@ def get_metadata(self, keys: list[Key], default: Any = no_default) -> Any:
             else:
                 raise
 
-    def set_restrictions(self, worker: dict[Key, Collection[str] | str | None]) -> None:
+    def set_restrictions(
+        self, worker: Mapping[Key, Collection[str] | str | None]
+    ) -> None:
         for key, restrictions in worker.items():
             ts = self.tasks[key]
             if isinstance(restrictions, str):
diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py
index c6fbbe210a1..bdfd46d84dd 100644
--- a/distributed/shuffle/_scheduler_plugin.py
+++ b/distributed/shuffle/_scheduler_plugin.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import contextlib
 import itertools
 import logging
@@ -95,11 +96,59 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
                 self.scheduler,
                 stimulus_id=f"p2p-barrier-inconsistent-{time()}",
             )
+
         msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
-        await self.scheduler.broadcast(
-            msg=msg,
-            workers=list(shuffle.participating_workers),
+        broadcast_task = asyncio.create_task(
+            self.scheduler.broadcast(
+                msg=msg,
+                workers=list(shuffle.participating_workers),
+            )
         )
+        barrier_task = self.scheduler.tasks[barrier_key(id)]
+        barrier_deps: set[TaskState] = barrier_task.dependents
+        from dask.optimization import SubgraphCallable
+
+        from distributed.shuffle._rechunk import rechunk_unpack
+        from distributed.shuffle._shuffle import shuffle_unpack
+
+        def _extract_part_id(run_spec: Any) -> Any:
+            # FIXME: This is extremely crude. The shuffle run / spec should
+            # likely expose a method that performs this check and returns
+            # the ID if possible.
+            if isinstance(run_spec, SubgraphCallable):
+                for tspec in run_spec.dsk.values():
+                    if (partial_index := _extract_part_id(tspec)) is not False:
+                        return partial_index
+                return False
+            if not isinstance(run_spec, tuple):
+                return False
+            if run_spec[0] is rechunk_unpack or run_spec[0] is shuffle_unpack:
+                # Happy path: an unfused unpack task
+                if len(run_spec) == 4:
+                    return run_spec[2]
+                return run_spec[1][1]
+            else:
+                if any((ret := _extract_part_id(arg)) is not False for arg in run_spec):
+                    return ret
+                return False
+
+        restrictions: dict[Key, set[str]] = {}
+        for dep in barrier_deps:
+            # Yield to the event loop so that the broadcast can make progress
+            # while we walk the dependents
+            if not broadcast_task.done():
+                await asyncio.sleep(0)
+            if (partial_index := _extract_part_id(dep.run_spec)) is not False:
+                worker = shuffle.run_spec.worker_for[partial_index]
+                restrictions[dep.key] = {worker}
+            else:
+                raise RuntimeError("Could not parse barrier dependents")
+
+        await broadcast_task
+        # Set the restrictions only after the barrier broadcast has completed
+        # so as not to interfere with concurrency control; setting restrictions
+        # after the barrier is well tested, setting them earlier is not
+        self.scheduler.set_restrictions(restrictions)
 
     def restrict_task(
         self, id: ShuffleId, run_id: int, key: Key, worker: str
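
Note for reviewers: below is a minimal, self-contained sketch of how the
_extract_part_id helper above walks a task's run_spec. The SubgraphCallable
stub and the two unpack stand-ins are illustrative only, not the real dask
objects; only the tuple shapes (spec[2] for the 4-tuple form, spec[1][1]
otherwise) mirror the patch.

    from typing import Any


    def rechunk_unpack(*args: Any) -> None:  # stand-in, not the real function
        ...


    def shuffle_unpack(*args: Any) -> None:  # stand-in, not the real function
        ...


    class SubgraphCallable:  # stand-in for dask.optimization.SubgraphCallable
        def __init__(self, dsk: dict[str, Any]) -> None:
            self.dsk = dsk


    def _extract_part_id(run_spec: Any) -> Any:
        # Same traversal as the patch: fused subgraphs first, then a direct
        # unpack call, then recursion into the tuple's arguments
        if isinstance(run_spec, SubgraphCallable):
            for tspec in run_spec.dsk.values():
                if (part := _extract_part_id(tspec)) is not False:
                    return part
            return False
        if not isinstance(run_spec, tuple):
            return False
        if run_spec[0] is rechunk_unpack or run_spec[0] is shuffle_unpack:
            return run_spec[2] if len(run_spec) == 4 else run_spec[1][1]
        if any((ret := _extract_part_id(arg)) is not False for arg in run_spec):
            return ret
        return False


    # Unfused rechunk task: the partial index sits at position 2
    assert _extract_part_id((rechunk_unpack, "s-abc", (0, 1), "barrier")) == (0, 1)

    # Fused task: the unpack call is buried inside a SubgraphCallable
    fused = (SubgraphCallable({"u": (shuffle_unpack, ("s-abc", 7), "barrier")}), "dep")
    assert _extract_part_id(fused) == 7

The comparisons against False (rather than plain truthiness) matter because a
valid partition index can be 0, which is falsy.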
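A second sketch of the concurrency pattern used in barrier(): the broadcast is
started as a task rather than awaited immediately, and the dependent walk
periodically yields so the RPC fan-out can overlap with it. The names here
(broadcast, barrier, the key and worker strings) are hypothetical stand-ins
for the scheduler methods, not the real API.

    import asyncio


    async def broadcast() -> str:
        # Stand-in for Scheduler.broadcast: a slow RPC fan-out to workers
        await asyncio.sleep(0.1)
        return "inputs done"


    async def barrier() -> None:
        # Start the broadcast without awaiting it; it now runs concurrently
        # with the bookkeeping below
        broadcast_task = asyncio.create_task(broadcast())

        restrictions: dict[str, set[str]] = {}
        for i in range(50_000):
            # Cooperatively yield so the broadcast keeps making progress
            # while this CPU-bound loop holds the event loop
            if not broadcast_task.done():
                await asyncio.sleep(0)
            restrictions[f"key-{i}"] = {"worker-0"}

        # Apply the side effect only once the broadcast has completed
        print(await broadcast_task, f"({len(restrictions)} restrictions)")


    asyncio.run(barrier())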