diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py
index 48510dfd41..4f799f5072 100644
--- a/distributed/shuffle/_core.py
+++ b/distributed/shuffle/_core.py
@@ -203,18 +203,37 @@ async def barrier(self, run_ids: Sequence[int]) -> int:
         return self.run_id
 
     async def _send(
-        self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
+        self,
+        address: str,
+        input_partitions: list[_T_partition_id],
+        output_partitions: list[_T_partition_id],
+        locs: list[_T_partition_id],
+        shards: list[Any] | bytes,
     ) -> OKMessage | ErrorMessage:
         self.raise_if_closed()
         return await self.rpc(address).shuffle_receive(
+            input_partitions=input_partitions,
+            output_partitions=output_partitions,
+            locs=locs,
             data=to_serialize(shards),
             shuffle_id=self.id,
             run_id=self.run_id,
         )
 
     async def send(
-        self, address: str, shards: list[tuple[_T_partition_id, Any]]
+        self, address: str, sharded: list[tuple[_T_partition_id, Any]]
     ) -> OKMessage | ErrorMessage:
+        ipids = []
+        opids = []
+        locs = []
+        shards = []
+        for input_partition, inshards in sharded:
+            for output_partition, shard in inshards:
+                loc, data = shard
+                ipids.append(input_partition)
+                opids.append(output_partition)
+                locs.append(loc)
+                shards.append(data)
         if _mean_shard_size(shards) < 65536:
             # Don't send buffers individually over the tcp comms.
             # Instead, merge everything into an opaque bytes blob, send it all at once,
@@ -226,7 +245,7 @@ async def send(
             shards_or_bytes = shards
 
         def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
-            return self._send(address, shards_or_bytes)
+            return self._send(address, ipids, opids, locs, shards_or_bytes)
 
         return await retry(
             _send,
@@ -308,13 +327,17 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
         return self._disk_buffer.read("_".join(str(i) for i in id))
 
     async def receive(
-        self, data: list[tuple[_T_partition_id, Any]] | bytes
+        self,
+        input_partitions: list[_T_partition_id],
+        output_partitions: list[_T_partition_type],
+        locs: list[_T_partition_id],
+        data: list[Any] | bytes,
     ) -> OKMessage | ErrorMessage:
         try:
             if isinstance(data, bytes):
                 # Unpack opaque blob. See send()
-                data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
-            await self._receive(data)
+                data = cast(list[Any], pickle.loads(data))
+            await self._receive(input_partitions, output_partitions, locs, data)
             return {"status": "OK"}
         except P2PConsistencyError as e:
             return error_message(e)
@@ -336,7 +359,9 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str:
         """Get the address of the worker assigned to the output partition"""
 
     @abc.abstractmethod
-    async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
+    async def _receive(
+        self, input_partitions: list[_T_partition_id], data: list[Any]
+    ) -> None:
         """Receive shards belonging to output partitions of this shuffle run"""
 
     def add_partition(
diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py
index b33e90730b..1f78da3c36 100644
--- a/distributed/shuffle/_rechunk.py
+++ b/distributed/shuffle/_rechunk.py
@@ -658,20 +658,23 @@ def __init__(
 
     async def _receive(
         self,
-        data: list[tuple[NDIndex, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]]],
+        input_partitions: list[NDIndex],
+        output_partitions: list[NDIndex],
+        locs: list[NDIndex],
+        data: list[np.ndarray],
     ) -> None:
         self.raise_if_closed()
 
         # Repartition shards and filter out already received ones
         shards = defaultdict(list)
-        for d in data:
-            id1, payload = d
-            if id1 in self.received:
+        for ipid, opid, loc, dat in zip(input_partitions, output_partitions, locs, data):
+            if ipid in self.received:
                 continue
-            self.received.add(id1)
-            for id2, shard in payload:
-                shards[id2].append(shard)
-            self.total_recvd += sizeof(d)
+            shards[opid].append((loc, dat))
+            self.total_recvd += sizeof(dat)
+        self.received.update(input_partitions)
+        del input_partitions
+        del output_partitions
         del data
         if not shards:
             return
diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py
index 57d2cfe369..d59e660128 100644
--- a/distributed/shuffle/_worker_plugin.py
+++ b/distributed/shuffle/_worker_plugin.py
@@ -311,6 +311,9 @@ async def shuffle_receive(
         self,
         shuffle_id: ShuffleId,
         run_id: int,
+        input_partitions,
+        output_partitions,
+        locs,
         data: list[tuple[int, Any]] | bytes,
     ) -> OKMessage | ErrorMessage:
         """
@@ -319,7 +322,7 @@ async def shuffle_receive(
         """
         try:
             shuffle_run = await self._get_shuffle_run(shuffle_id, run_id)
-            return await shuffle_run.receive(data)
+            return await shuffle_run.receive(input_partitions, output_partitions, locs, data)
         except P2PConsistencyError as e:
             return error_message(e)
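For reference, a minimal standalone sketch of the payload change in send(): the old code shipped the nested
list[tuple[NDIndex, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]]] structure, while the new code flattens it
into four parallel lists (input partition ids, output partition ids, locations, shard buffers). The toy indices
and bytes objects below are illustrative stand-ins (they are not taken from the patch); real keys are NDIndex
tuples and real shards are np.ndarray objects.

    from collections import defaultdict

    # Illustrative toy stand-in for the nested `sharded` argument passed to send():
    # list of (input_partition, [(output_partition, (loc, shard)), ...]) entries.
    sharded = [
        ((0, 0), [((1, 0), ((0, 5), b"shard-a")), ((1, 1), ((0, 0), b"shard-b"))]),
        ((0, 1), [((1, 0), ((3, 5), b"shard-c"))]),
    ]

    # Flattening done by the new send(): four parallel lists instead of nesting.
    ipids, opids, locs, shards = [], [], [], []
    for input_partition, inshards in sharded:
        for output_partition, shard in inshards:
            loc, data = shard
            ipids.append(input_partition)
            opids.append(output_partition)
            locs.append(loc)
            shards.append(data)

    # The receiving side can regroup shards per output partition with zip(),
    # mirroring the new _receive in _rechunk.py above.
    regrouped = defaultdict(list)
    for opid, loc, data in zip(opids, locs, shards):
        regrouped[opid].append((loc, data))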