Skip to content

Temporarily disable gpuCI update CI job (#8945) #15320

Temporarily disable gpuCI update CI job (#8945)

Temporarily disable gpuCI update CI job (#8945) #15320

GitHub Actions / Unit Test Results failed Nov 20, 2024 in 0s

43 fail, 112 skipped, 3 975 pass in 5h 36m 44s

    17 files   -      8      17 suites   - 8   5h 36m 44s ⏱️ - 4h 35m 53s
 4 130 tests ±     0   3 975 ✅ +     1    112 💤 +  2   43 ❌  -   3 
30 127 runs   - 17 565  28 512 ✅  - 16 618  1 383 💤  - 738  232 ❌  - 209 

Results for commit af77cfc. ± Comparison against earlier commit 750cb91.

Annotations

Check warning on line 0 in distributed.tests.test_client

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

8 out of 9 runs failed: test_persist_async (distributed.tests.test_client)

artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 0s]
Raw output
IndexError: tuple index out of range
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40231', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:34811', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:36687', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>

    @gen_cluster(client=True)
    async def test_persist_async(c, s, a, b):
        pytest.importorskip("numpy")
        da = pytest.importorskip("dask.array")
    
        x = da.ones((10, 10), chunks=(5, 10))
        y = 2 * (x + 1)
        assert len(y.dask) == 6
        yy = c.persist(y)
    
        assert len(y.dask) == 6
        assert len(yy.dask) == 2
        assert all(isinstance(v, Future) for v in yy.dask.values())
        assert yy.__dask_keys__() == y.__dask_keys__()
    
        g, h = c.compute([y, yy])
    
>       gg, hh = await c.gather([g, h])

distributed/tests/test_client.py:2565: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:2427: in _gather
    raise exception.with_traceback(traceback)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:1284: in finalize
    return concatenate3(results)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5452: in concatenate3
    chunks = chunks_from_arrays(arrays)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5237: in chunks_from_arrays
    result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import contextlib
    import math
    import operator
    import os
    import pickle
    import re
    import sys
    import traceback
    import uuid
    import warnings
    from bisect import bisect
    from collections import defaultdict
    from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
    from functools import lru_cache, partial, reduce, wraps
    from itertools import product, zip_longest
    from numbers import Integral, Number
    from operator import add, mul
    from threading import Lock
    from typing import Any, Literal, TypeVar, Union, cast
    
    import numpy as np
    from numpy.typing import ArrayLike
    from packaging.version import Version
    from tlz import accumulate, concat, first, groupby, partition
    from tlz.curried import pluck
    from toolz import frequencies
    
    from dask import compute, config, core
    from dask.array import chunk
    from dask.array.chunk import getitem
    from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
    
    # Keep einsum_lookup and tensordot_lookup here for backwards compatibility
    from dask.array.dispatch import (  # noqa: F401
        concatenate_lookup,
        einsum_lookup,
        tensordot_lookup,
    )
    from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
    from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
    from dask.array.utils import compute_meta, meta_from_array
    from dask.base import (
        DaskMethodsMixin,
        compute_as_if_collection,
        dont_optimize,
        is_dask_collection,
        named_schedulers,
        persist,
        tokenize,
    )
    from dask.blockwise import blockwise as core_blockwise
    from dask.blockwise import broadcast_dimensions
    from dask.context import globalmethod
    from dask.core import quote
    from dask.delayed import Delayed, delayed
    from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
    from dask.layers import ArrayBlockIdDep, ArraySliceDep, ArrayValuesDep, reshapelist
    from dask.sizeof import sizeof
    from dask.typing import Graph, Key, NestedKeys
    from dask.utils import (
        IndexCallable,
        SerializableLock,
        cached_cumsum,
        cached_property,
        concrete,
        derived_from,
        format_bytes,
        funcname,
        has_keyword,
        is_arraylike,
        is_dataframe_like,
        is_index_like,
        is_integer,
        is_series_like,
        maybe_pluralize,
        ndeepmap,
        ndimlist,
        parse_bytes,
        typename,
    )
    from dask.widgets import get_template
    
    # Type alias for values that are an ``int`` or a NaN ``float`` (see the
    # trailing comment); presumably used for possibly-unknown chunk sizes.
    T_IntOrNaN = Union[int, float]  # Should be Union[int, Literal[np.nan]]

    # Default scheduler: the registered "threads" scheduler, falling back to the
    # "sync" scheduler when "threads" is not present in ``named_schedulers``.
    # Note that ``named_schedulers["sync"]`` is looked up eagerly either way.
    DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])

    # Help text (it starts with a blank line) pointing users at the docs on
    # unknown chunk sizes; presumably appended to error messages elsewhere in
    # this module -- confirm against its users.
    unknown_chunk_message = (
        "\n\n"
        "A possible solution: "
        "https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n"
        "Summary: to compute chunks sizes, use\n\n"
        "   x.compute_chunk_sizes()  # for Dask Array `x`\n"
        "   ddf.to_dask_array(lengths=True)  # for Dask DataFrame `ddf`"
    )
    
    
    class PerformanceWarning(Warning):
        """Warning category for chunking choices that are likely to hurt performance."""
    
    
    def getter(a, b, asarray=True, lock=None):
        """Return ``a[b]``, optionally under a lock and coerced to a NumPy array.

        A tuple index containing ``None`` (new axes) is applied in two steps:
        ``a`` is first indexed with the ``None`` entries removed, and the result
        is then indexed with ``None`` / ``slice(None, None)`` for every
        non-integer position of the original index.

        Parameters
        ----------
        a : object
            Anything supporting ``__getitem__``.
        b : index
            The index applied to ``a``.
        asarray : bool, default True
            When true, results that are not array-like, or that are
            ``np.matrix``, are converted with ``np.asarray``.
        lock : lock-like or falsy, default None
            When truthy, ``lock.acquire()`` / ``lock.release()`` bracket the
            indexing; the release happens even if indexing raises.
        """
        if isinstance(b, tuple) and any(idx is None for idx in b):
            inner_index = tuple(idx for idx in b if idx is not None)
            outer_index = tuple(
                slice(None, None) if idx is not None else None
                for idx in b
                if not isinstance(idx, Integral)
            )
            sliced = getter(a, inner_index, asarray=asarray, lock=lock)
            return sliced[outer_index]

        if lock:
            lock.acquire()
        try:
            result = a[b]
            # ``np.matrix`` is array-like (so ``is_arraylike`` accepts it), but
            # it is still converted so callers keep getting a plain ndarray, as
            # Dask has historically done.
            if asarray and (not is_arraylike(result) or isinstance(result, np.matrix)):
                result = np.asarray(result)
        finally:
            if lock:
                lock.release()
        return result
    
    
    def getter_nofancy(a, b, asarray=True, lock=None):
        """Alias of ``getter`` for backends that lack fancy indexing.

        It behaves exactly like ``getter``. It exists as a distinct function
        object so that optimization passes can tell these tasks apart and know
        the backend does not support fancy indexing.
        """
        return getter(a, b, lock=lock, asarray=asarray)
    
    
    def getter_inline(a, b, asarray=True, lock=None):
        """Variant of ``getter`` that graph optimizations are free to inline.

        The slicing itself is delegated unchanged to ``getter``. Tasks using this
        function may be folded into their dependents. For example, this graph

        >>> a = x[:10]  # doctest: +SKIP
        >>> b = a + 1  # doctest: +SKIP
        >>> c = a * 2  # doctest: +SKIP

        may be rewritten as

        >>> b = x[:10] + 1  # doctest: +SKIP
        >>> c = x[:10] * 2  # doctest: +SKIP

        Such inlining can matter when the sliced data is read from disk.
        """
        return getter(a, b, lock=lock, asarray=asarray)
    
    
    from dask.array.optimization import fuse_slice, optimize
    
    # __array_function__ dict for mapping aliases and mismatching names
    _HANDLED_FUNCTIONS = {}
    
    
    def implements(*numpy_functions):
        """Register an __array_function__ implementation for dask.array.Array
    
        Register that a function implements the API of a NumPy function (or several
        NumPy functions in case of aliases) which is handled with
        ``__array_function__``.
    
        Parameters
        ----------
        \\*numpy_functions : callables
            One or more NumPy functions that are handled by ``__array_function__``
            and will be mapped by `implements` to a `dask.array` function.
        """
    
        def decorator(dask_func):
            for numpy_function in numpy_functions:
                _HANDLED_FUNCTIONS[numpy_function] = dask_func
    
            return dask_func
    
        return decorator
    
    
    def _should_delegate(self, other) -> bool:
        """Check whether Dask should delegate to the other.
        This implementation follows NEP-13:
        https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
        """
        if hasattr(other, "__array_ufunc__") and other.__array_ufunc__ is None:
            return True
        elif (
            hasattr(other, "__array_ufunc__")
            and not is_valid_array_chunk(other)
            # don't delegate to our own parent classes
            and not isinstance(self, type(other))
            and type(self) is not type(other)
        ):
            return True
        return False
    
    
    def check_if_handled_given_other(f):
        """Make a binary dunder method defer to upcast operand types.

        The returned method answers ``NotImplemented`` whenever
        ``_should_delegate(self, other)`` is true; otherwise it calls ``f``.
        Unknown types are therefore not assumed to be downcast types.
        """

        @wraps(f)
        def _deferring(self, other):
            if not _should_delegate(self, other):
                return f(self, other)
            return NotImplemented

        return _deferring
    
    
    def slices_from_chunks(chunks):
        """List the slices selecting every chunk, in ``itertools.product`` order.

        ``chunks`` holds one tuple of block sizes per axis. The result has one
        tuple of ``slice`` objects per block, with the last axis varying fastest.

        >>> slices_from_chunks(((2, 2), (3, 3, 3)))  # doctest: +NORMALIZE_WHITESPACE
         [(slice(0, 2, None), slice(0, 3, None)),
          (slice(0, 2, None), slice(3, 6, None)),
          (slice(0, 2, None), slice(6, 9, None)),
          (slice(2, 4, None), slice(0, 3, None)),
          (slice(2, 4, None), slice(3, 6, None)),
          (slice(2, 4, None), slice(6, 9, None))]
        """
        per_axis = []
        for sizes in chunks:
            offsets = cached_cumsum(sizes, initial_zero=True)
            per_axis.append(
                [slice(start, start + size) for start, size in zip(offsets, sizes)]
            )
        return list(product(*per_axis))
    
    
    def graph_from_arraylike(
        arr,  # Any array-like which supports slicing
        chunks,
        shape,
        name,
        getitem=getter,
        lock=False,
        asarray=True,
        dtype=None,
        inline_array=False,
    ) -> HighLevelGraph:
        """
        HighLevelGraph for slicing chunks from an array-like according to a chunk pattern.

        If ``inline_array`` is True, this makes a Blockwise layer of slicing tasks where
        the array-like is embedded into every task.

        If ``inline_array`` is False, this inserts the array-like as a standalone value
        (under the key ``"original-" + name``) in a MaterializedLayer, then generates a
        Blockwise layer of slicing tasks that refer to it.

        Parameters
        ----------
        arr : array-like
            Any object supporting slicing with a tuple of ``slice`` objects.
        chunks, shape, dtype
            ``chunks`` is normalized with ``normalize_chunks(chunks, shape, dtype=dtype)``;
            ``shape`` also fixes the number of dimensions of the output keys.
        name : str
            Name of the output layer; task keys are ``(name, i, j, ...)``.
        getitem : callable, default ``getter``
            Slicing function placed at the head of every task.
        lock, asarray
            Forwarded to ``getitem`` as keywords only when ``getitem`` accepts both
            ``asarray`` and ``lock`` and the call is not the common case
            (``asarray`` truthy and ``lock`` falsy); otherwise they are dropped.
        inline_array : bool, default False
            Whether to embed ``arr`` in every task (see above).

        Returns
        -------
        HighLevelGraph

        Examples
        --------
        >>> dict(graph_from_arraylike(arr, chunks=(2, 3), shape=(4, 6), name="X", inline_array=True))  # doctest: +SKIP
        {('X', 0, 0): (getter, arr, (slice(0, 2), slice(0, 3))),
         ('X', 1, 0): (getter, arr, (slice(2, 4), slice(0, 3))),
         ('X', 1, 1): (getter, arr, (slice(2, 4), slice(3, 6))),
         ('X', 0, 1): (getter, arr, (slice(0, 2), slice(3, 6)))}

        >>> dict(  # doctest: +SKIP
                graph_from_arraylike(arr, chunks=((2, 2), (3, 3)), shape=(4,6), name="X", inline_array=False)
            )
        {"original-X": arr,
         ('X', 0, 0): (getter, 'original-X', (slice(0, 2), slice(0, 3))),
         ('X', 1, 0): (getter, 'original-X', (slice(2, 4), slice(0, 3))),
         ('X', 1, 1): (getter, 'original-X', (slice(2, 4), slice(3, 6))),
         ('X', 0, 1): (getter, 'original-X', (slice(0, 2), slice(3, 6)))}
        """
        chunks = normalize_chunks(chunks, shape, dtype=dtype)
        out_ind = tuple(range(len(shape)))

        if (
            has_keyword(getitem, "asarray")
            and has_keyword(getitem, "lock")
            and (not asarray or lock)
        ):
            kwargs = {"asarray": asarray, "lock": lock}
        else:
            # Common case, drop extra parameters
            kwargs = {}

        if inline_array:
            # ``arr`` itself is passed (as a literal, index ``None``) to every task.
            layer = core_blockwise(
                getitem,
                name,
                out_ind,
                arr,
                None,
                ArraySliceDep(chunks),
                out_ind,
                numblocks={},
                **kwargs,
            )
            return HighLevelGraph.from_collections(name, layer)

        # Store ``arr`` once under its own key; the slicing tasks reference that key.
        original_name = "original-" + name
        layers = {
            original_name: MaterializedLayer({original_name: arr}),
            name: core_blockwise(
                getitem,
                name,
                out_ind,
                original_name,
                None,
                ArraySliceDep(chunks),
                out_ind,
                numblocks={},
                **kwargs,
            ),
        }
        deps = {
            original_name: set(),
            name: {original_name},
        }
        return HighLevelGraph(layers, deps)
    
    
    def dotmany(A, B, leftfunc=None, rightfunc=None, **kwargs):
        """Sum of ``np.dot`` products over pairs of aligned chunks.

        Computes ``sum(np.dot(a, b, **kwargs) for a, b in zip(A, B))``. When
        ``leftfunc`` / ``rightfunc`` are given, each ``a`` / ``b`` is passed
        through them first. Pairing stops at the shorter of ``A`` and ``B``;
        with no pairs at all the result is ``0``.

        >>> x = np.array([[1, 2], [1, 2]])
        >>> y = np.array([[10, 20], [10, 20]])
        >>> dotmany([x, x, x], [y, y, y])
        array([[ 90, 180],
               [ 90, 180]])

        Optionally pass in functions to apply to the left and right chunks

        >>> dotmany([x, x, x], [y, y, y], rightfunc=np.transpose)
        array([[150, 150],
               [150, 150]])
        """
        lhs = map(leftfunc, A) if leftfunc else A
        rhs = map(rightfunc, B) if rightfunc else B
        dot = partial(np.dot, **kwargs)
        return sum(dot(a, b) for a, b in zip(lhs, rhs))
    
    
    def _concatenate2(arrays, axes=None):
        """Recursively concatenate nested lists of arrays along axes
    
        Each entry in axes corresponds to each level of the nested list.  The
        length of axes should correspond to the level of nesting of arrays.
        If axes is an empty list or tuple, return arrays, or arrays[0] if
        arrays is a list.
    
        >>> x = np.array([[1, 2], [3, 4]])
        >>> _concatenate2([x, x], axes=[0])
        array([[1, 2],
               [3, 4],
               [1, 2],
               [3, 4]])
    
        >>> _concatenate2([x, x], axes=[1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4],
               [1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        Supports Iterators
        >>> _concatenate2(iter([x, x]), axes=[1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        Special Case
        >>> _concatenate2([x, x], axes=())
        array([[1, 2],
               [3, 4]])
        """
        if axes is None:
            axes = []
    
        if axes == ():
            if isinstance(arrays, list):
                return arrays[0]
            else:
                return arrays
    
        if isinstance(arrays, Iterator):
            arrays = list(arrays)
        if not isinstance(arrays, (list, tuple)):
            return arrays
        if len(axes) > 1:
            arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
        concatenate = concatenate_lookup.dispatch(
            type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0)))
        )
        if isinstance(arrays[0], dict):
            # Handle concatenation of `dict`s, used as a replacement for structured
            # arrays when that's not supported by the array library (e.g., CuPy).
            keys = list(arrays[0].keys())
            assert all(list(a.keys()) == keys for a in arrays)
            ret = dict()
            for k in keys:
                ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0])
            return ret
        else:
            return concatenate(arrays, axis=axes[0])
    
    
    def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
        """
        Tries to infer output dtype of ``func`` for a small set of input arguments.

        Parameters
        ----------
        func: Callable
            Function for which output dtype is to be determined

        args: List of array like
            Arguments to the function, which would usually be used. Only attributes
            ``ndim`` and ``dtype`` are used.

        kwargs: dict
            Additional ``kwargs`` to the ``func``

        funcname: String
            Name of calling function to improve potential error messages

        suggest_dtype: None/False or String
            If not ``None`` adds suggestion to potential error message to specify a dtype
            via the specified kwarg. Defaults to ``'dtype'``.

        nout: None or Int
            ``None`` if function returns single output, integer if many.
            Defaults to ``None``.

        Returns
        -------
        : dtype or List of dtype
            One or many dtypes (depending on ``nout``). An output without a
            ``dtype`` attribute is reported by its ``type`` instead.

        Raises
        ------
        ValueError
            If calling ``func`` on the placeholder arguments raises; the message
            contains the original error and its formatted traceback.
        """
        from dask.array.utils import meta_from_array

        # Replace every array-like argument with a tiny all-ones array (size 1 in
        # every dimension) built from ``meta_from_array(x)`` with the same dtype
        # and ndim, so that ``func`` runs cheaply.
        args = [
            (
                np.ones_like(meta_from_array(x), shape=((1,) * x.ndim), dtype=x.dtype)
                if is_arraylike(x)
                else x
            )
            for x in args
        ]
        try:
            with np.errstate(all="ignore"):
                o = func(*args, **kwargs)
        except Exception as e:
            tb = "".join(traceback.format_tb(e.__traceback__))
            suggest = (
                "Please specify the dtype explicitly using the "
                f"`{suggest_dtype}` kwarg.\n\n"
                if suggest_dtype
                else ""
            )
            msg = (
                f"`dtype` inference failed in `{funcname}`.\n\n"
                f"{suggest}"
                "Original error is below:\n"
                "------------------------\n"
                f"{e!r}\n\n"
                "Traceback:\n"
                "---------\n"
                f"{tb}"
            )
        else:
            msg = None
        # Raised outside the ``except`` block, so the ValueError is not chained
        # to the original exception (whose details are already in ``msg``).
        if msg is not None:
            raise ValueError(msg)
        if nout is None:
            return getattr(o, "dtype", type(o))
        # Same fallback as the single-output case for outputs without ``dtype``.
        return tuple(getattr(out, "dtype", type(out)) for out in o)
    
    
    def normalize_arg(x):
        """Prepare a user-supplied argument for ``blockwise`` or ``map_blocks``.

        Dask collections are returned untouched. Any other ``x`` is wrapped with
        ``dask.delayed``, putting it into the graph on its own, when it is:

        1.  a string that ``re.match`` finds starting with an underscore followed
            by digits, which might collide with ``blockwise_token`` names;
        2.  a list of 10 or more elements; or
        3.  large, as measured by ``sizeof(x) > 1e6``.

        Everything else is returned as is.
        """
        if is_dask_collection(x):
            return x
        looks_like_token = isinstance(x, str) and re.match(r"_\d+", x)
        if looks_like_token or (isinstance(x, list) and len(x) >= 10) or sizeof(x) > 1e6:
            return delayed(x)
        return x
    
    
    def _pass_extra_kwargs(func, keys, *args, **kwargs):
        """Helper for :func:`dask.array.map_blocks` to pass `block_info` or `block_id`.
    
        For each element of `keys`, a corresponding element of args is changed
        to a keyword argument with that key, before all arguments re passed on
        to `func`.
        """
        kwargs.update(zip(keys, args))
        return func(*args[len(keys) :], **kwargs)
    
    
    def map_blocks(
        func,
        *args,
        name=None,
        token=None,
        dtype=None,
        chunks=None,
        drop_axis=None,
        new_axis=None,
        enforce_ndim=False,
        meta=None,
        **kwargs,
    ):
        """Map a function across all blocks of a dask array.
    
        Note that ``map_blocks`` will attempt to automatically determine the output
        array type by calling ``func`` on 0-d versions of the inputs. Please refer to
        the ``meta`` keyword argument below if you expect that the function will not
        succeed when operating on 0-d arrays.
    
        Parameters
        ----------
        func : callable
            Function to apply to every block in the array.
            If ``func`` accepts ``block_info=`` or ``block_id=``
            as keyword arguments, these will be passed dictionaries
            containing information about input and output chunks/arrays
            during computation. See examples for details.
        args : dask arrays or other objects
        dtype : np.dtype, optional
            The ``dtype`` of the output array. It is recommended to provide this.
            If not provided, will be inferred by applying the function to a small
            set of fake data.
        chunks : tuple, optional
            Chunk shape of resulting blocks if the function does not preserve
            shape. If not provided, the resulting array is assumed to have the same
            block structure as the first input array.
        drop_axis : number or iterable, optional
            Dimensions lost by the function.
        new_axis : number or iterable, optional
            New dimensions created by the function. Note that these are applied
            after ``drop_axis`` (if present). The size of each chunk along this
            dimension will be set to 1. Please specify ``chunks`` if the individual
            chunks have a different size.
        enforce_ndim : bool, default False
            Whether to enforce at runtime that the dimensionality of the array
            produced by ``func`` actually matches that of the array returned by
            ``map_blocks``.
            If True, this will raise an error when there is a mismatch.
        token : string, optional
            The key prefix to use for the output array. If not provided, will be
            determined from the function name.
        name : string, optional
            The key name to use for the output array. Note that this fully
            specifies the output key name, and must be unique. If not provided,
            will be determined by a hash of the arguments.
        meta : array-like, optional
            The ``meta`` of the output array, when specified is expected to be an
            array of the same type and dtype of that returned when calling ``.compute()``
            on the array returned by this function. When not provided, ``meta`` will be
            inferred by applying the function to a small set of fake data, usually a
            0-d array. It's important to ensure that ``func`` can successfully complete
            computation without raising exceptions when 0-d is passed to it, providing
            ``meta`` will be required otherwise. If the output type is known beforehand
            (e.g., ``np.ndarray``, ``cupy.ndarray``), an empty array of such type dtype
            can be passed, for example: ``meta=np.array((), dtype=np.int32)``.
        **kwargs :
            Other keyword arguments to pass to function. Values must be constants
            (not dask.arrays)
    
        See Also
        --------
        dask.array.map_overlap : Generalized operation with overlap between neighbors.
        dask.array.blockwise : Generalized operation with control over block alignment.
    
        Examples
        --------
        >>> import dask.array as da
        >>> x = da.arange(6, chunks=3)
    
        >>> x.map_blocks(lambda x: x * 2).compute()
        array([ 0,  2,  4,  6,  8, 10])
    
        The ``da.map_blocks`` function can also accept multiple arrays.
    
        >>> d = da.arange(5, chunks=2)
        >>> e = da.arange(5, chunks=2)
    
        >>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
        >>> f.compute()
        array([ 0,  2,  6, 12, 20])
    
        If the function changes shape of the blocks then you must provide chunks
        explicitly.
    
        >>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))
    
        You have a bit of freedom in specifying chunks.  If all of the output chunk
        sizes are the same, you can provide just that chunk size as a single tuple.
    
        >>> a = da.arange(18, chunks=(6,))
        >>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))
    
        If the function changes the dimension of the blocks you must specify the
        created or destroyed dimensions.
    
        >>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
        ...                  new_axis=[0, 2])
    
        If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
        add the necessary number of axes on the left.
    
        Note that ``map_blocks()`` will concatenate chunks along axes specified by
        the keyword parameter ``drop_axis`` prior to applying the function.
        This is illustrated in the figure below:
    
        .. image:: /images/map_blocks_drop_axis.png
    
        Due to memory-size-constraints, it is often not advisable to use ``drop_axis``
        on an axis that is chunked.  In that case, it is better not to use
        ``map_blocks`` but rather
        ``dask.array.reduction(..., axis=dropped_axes, concatenate=False)`` which
        maintains a leaner memory footprint while it drops any axis.
    
        Map_blocks aligns blocks by block positions without regard to shape. In the
        following example we have two arrays with the same number of blocks but
        with different shape and chunk sizes.
    
        >>> x = da.arange(1000, chunks=(100,))
        >>> y = da.arange(100, chunks=(10,))
    
        The relevant attribute to match is numblocks.
    
        >>> x.numblocks
        (10,)
        >>> y.numblocks
        (10,)
    
        If these match (up to broadcasting rules) then we can map arbitrary
        functions across blocks
    
        >>> def func(a, b):
        ...     return np.array([a.max(), b.max()])
    
        >>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
        dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
    
        >>> _.compute()
        array([ 99,   9, 199,  19, 299,  29, 399,  39, 499,  49, 599,  59, 699,
                69, 799,  79, 899,  89, 999,  99])
    
        Your block function can get information about where it is in the array by
        accepting a special ``block_info`` or ``block_id`` keyword argument.
        During computation, they will contain information about each of the input
        and output chunks (and dask arrays) relevant to each call of ``func``.
    
        >>> def func(block_info=None):
        ...     pass
    
        This will receive the following information:
    
        >>> block_info  # doctest: +SKIP
        {0: {'shape': (1000,),
             'num-chunks': (10,),
             'chunk-location': (4,),
             'array-location': [(400, 500)]},
         None: {'shape': (1000,),
                'num-chunks': (10,),
                'chunk-location': (4,),
                'array-location': [(400, 500)],
                'chunk-shape': (100,),
                'dtype': dtype('float64')}}
    
        The keys to the ``block_info`` dictionary indicate which is the input and
        output Dask array:
    
        - **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
          The dictionary key is ``0`` because that is the argument index corresponding
          to the first input Dask array.
          In cases where multiple Dask arrays have been passed as input to the function,
          you can access them with the number corresponding to the input argument,
          eg: ``block_info[1]``, ``block_info[2]``, etc.
          (Note that if you pass multiple Dask arrays as input to map_blocks,
          the arrays must match each other by having matching numbers of chunks,
          along corresponding dimensions up to broadcasting rules.)
        - **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
          and contains information about the output chunks.
          The output chunk shape and dtype may may be different than the input chunks.
    
        For each dask array, ``block_info`` describes:
    
        - ``shape``: the shape of the full Dask array,
        - ``num-chunks``: the number of chunks of the full array in each dimension,
        - ``chunk-location``: the chunk location (for example the fourth chunk over
          in the first dimension), and
        - ``array-location``: the array location within the full Dask array
          (for example the slice corresponding to ``40:50``).
    
        In addition to these, there are two extra parameters described by
        ``block_info`` for the output array (in ``block_info[None]``):
    
        - ``chunk-shape``: the output chunk shape, and
        - ``dtype``: the output dtype.
    
        These features can be combined to synthesize an array from scratch, for
        example:
    
        >>> def func(block_info=None):
        ...     loc = block_info[None]['array-location'][0]
        ...     return np.arange(loc[0], loc[1])
    
        >>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
        dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>
    
        >>> _.compute()
        array([0, 1, 2, 3, 4, 5, 6, 7])
    
        ``block_id`` is similar to ``block_info`` but contains only the ``chunk_location``:
    
        >>> def func(block_id=None):
        ...     pass
    
        This will receive the following information:
    
        >>> block_id  # doctest: +SKIP
        (4, 3)
    
        You may specify the key name prefix of the resulting task in the graph with
        the optional ``token`` keyword argument.
    
        >>> x.map_blocks(lambda x: x + 1, name='increment')
        dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>
    
        For functions that may not handle 0-d arrays, it's also possible to specify
        ``meta`` with an empty array matching the type of the expected result. In
        the example below, ``func`` will result in an ``IndexError`` when computing
        ``meta``:
    
        >>> rng = da.random.default_rng()
        >>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
        dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
    
        Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
        a ``dtype``:
    
        >>> import cupy  # doctest: +SKIP
        >>> rng = da.random.default_rng(cupy.random.default_rng())  # doctest: +SKIP
        >>> dt = np.float32
        >>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt))  # doctest: +SKIP
        dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
        """
        if drop_axis is None:
            drop_axis = []
    
        if not callable(func):
            msg = (
                "First argument must be callable function, not %s\n"
                "Usage:   da.map_blocks(function, x)\n"
                "   or:   da.map_blocks(function, x, y, z)"
            )
            raise TypeError(msg % type(func).__name__)
        if token:
            warnings.warn(
                "The `token=` keyword to `map_blocks` has been moved to `name=`. "
                "Please use `name=` instead as the `token=` keyword will be removed "
                "in a future release.",
                category=FutureWarning,
            )
            name = token
    
        name = f"{name or funcname(func)}-{tokenize(func, dtype, chunks, drop_axis, new_axis, *args, **kwargs)}"
        new_axes = {}
    
        if isinstance(drop_axis, Number):
            drop_axis = [drop_axis]
        if isinstanc…tack
        """
        from dask.array import wrap
    
        seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]
    
        if not seq:
            raise ValueError("Need array(s) to concatenate")
    
        if axis is None:
            seq = [a.flatten() for a in seq]
            axis = 0
    
        seq_metas = [meta_from_array(s) for s in seq]
        _concatenate = concatenate_lookup.dispatch(
            type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0)))
        )
        meta = _concatenate(seq_metas, axis=axis)
    
        # Promote types to match meta
        seq = [a.astype(meta.dtype) for a in seq]
    
        # Find output array shape
        ndim = len(seq[0].shape)
        shape = tuple(
            sum(a.shape[i] for a in seq) if i == axis else seq[0].shape[i]
            for i in range(ndim)
        )
    
        # Drop empty arrays
        seq2 = [a for a in seq if a.size]
        if not seq2:
            seq2 = seq
    
        if axis < 0:
            axis = ndim + axis
        if axis >= ndim:
            msg = (
                "Axis must be less than than number of dimensions"
                "\nData has %d dimensions, but got axis=%d"
            )
            raise ValueError(msg % (ndim, axis))
    
        n = len(seq2)
        if n == 0:
            try:
                return wrap.empty_like(meta, shape=shape, chunks=shape, dtype=meta.dtype)
            except TypeError:
                return wrap.empty(shape, chunks=shape, dtype=meta.dtype)
        elif n == 1:
            return seq2[0]
    
        if not allow_unknown_chunksizes and not all(
            i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2)
            for i in range(ndim)
        ):
            if any(map(np.isnan, seq2[0].shape)):
                raise ValueError(
                    "Tried to concatenate arrays with unknown"
                    " shape %s.\n\nTwo solutions:\n"
                    "  1. Force concatenation pass"
                    " allow_unknown_chunksizes=True.\n"
                    "  2. Compute shapes with "
                    "[x.compute_chunk_sizes() for x in seq]" % str(seq2[0].shape)
                )
            raise ValueError("Shapes do not align: %s", [x.shape for x in seq2])
    
        inds = [list(range(ndim)) for i in range(n)]
        for i, ind in enumerate(inds):
            ind[axis] = -(i + 1)
    
        uc_args = list(concat(zip(seq2, inds)))
        _, seq2 = unify_chunks(*uc_args, warn=False)
    
        bds = [a.chunks for a in seq2]
    
        chunks = (
            seq2[0].chunks[:axis]
            + (sum((bd[axis] for bd in bds), ()),)
            + seq2[0].chunks[axis + 1 :]
        )
    
        cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in seq2]))
    
        names = [a.name for a in seq2]
    
        name = "concatenate-" + tokenize(names, axis)
        keys = list(product([name], *[range(len(bd)) for bd in chunks]))
    
        values = [
            (names[bisect(cum_dims, key[axis + 1]) - 1],)
            + key[1 : axis + 1]
            + (key[axis + 1] - cum_dims[bisect(cum_dims, key[axis + 1]) - 1],)
            + key[axis + 2 :]
            for key in keys
        ]
    
        dsk = dict(zip(keys, values))
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq2)
    
        return Array(graph, name, chunks, meta=meta)
    
    
    def load_store_chunk(
        x: Any,
        out: Any,
        index: slice,
        lock: Any,
        return_stored: bool,
        load_stored: bool,
    ):
        """
        A function inserted in a Dask graph for storing a chunk.

        Parameters
        ----------
        x: array-like or None
            An array (potentially a NumPy one) to write into ``out``. When
            ``x`` is ``None`` or has size 0 nothing is written, which lets the
            same function serve load-only tasks.
        out: array-like
            Where to store results.
        index: slice-like
            Where to store result from ``x`` in ``out`` (e.g. a tuple of
            slices, one per dimension).
        lock: Lock-like or False
            Lock held while writing to and reading from ``out``. A falsy value
            disables locking.
        return_stored: bool
            Whether to return ``out``.
        load_stored: bool
            Whether to return the array stored in ``out``.
            Ignored if ``return_stored`` is not ``True``.

        Returns
        -------

        If return_stored=True and load_stored=False
            out
        If return_stored=True and load_stored=True
            out[index]
        If return_stored=False
            None

        Examples
        --------

        >>> a = np.ones((5, 6))
        >>> b = np.empty(a.shape)
        >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
        """
        if lock:
            lock.acquire()
        try:
            # Load-only calls pass ``x=None``; empty chunks have nothing to write.
            if x is not None and x.size != 0:
                if is_arraylike(x):
                    out[index] = x
                else:
                    out[index] = np.asanyarray(x)

            if not return_stored:
                return None
            # The read below happens before ``finally`` releases the lock.
            return out[index] if load_stored else out
        finally:
            if lock:
                lock.release()
    
    
    def store_chunk(
        x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
    ):
        """Write ``x`` into ``out[index]`` without reading the stored data back.

        Thin wrapper around ``load_store_chunk`` with ``load_stored`` fixed to
        ``False``: the result is ``out`` when ``return_stored`` is true and
        ``None`` otherwise.
        """
        return load_store_chunk(
            x, out, index, lock, return_stored=return_stored, load_stored=False
        )
    
    
    # Type variable bound to ArrayLike; ``load_chunk`` below is annotated to
    # return the same array type it receives as ``out``.
    A = TypeVar("A", bound=ArrayLike)
    
    
    def load_chunk(out: A, index: slice, lock: Any) -> A:
        """Return ``out[index]``, holding ``lock`` if one is given; writes nothing."""
        return load_store_chunk(
            None, out, index, lock, return_stored=True, load_stored=True
        )
    
    
    def insert_to_ooc(
        keys: list,
        chunks: tuple[tuple[int, ...], ...],
        out: ArrayLike,
        name: str,
        *,
        lock: Lock | bool = True,
        region: tuple[slice, ...] | slice | None = None,
        return_stored: bool = False,
        load_stored: bool = False,
    ) -> dict:
        """
        Build the Dask graph that writes each chunk of an array into ``out``.

        Parameters
        ----------
        keys: list
            (Possibly nested) Dask keys of the input array. They are flattened
            and paired, in order, with the slices derived from ``chunks``.
        chunks: tuple
            Dask chunks of the input array
        out: array-like
            Where to store results to
        name: str
            First element of the keys of the generated tasks
        lock: Lock-like or bool, optional
            Whether to lock or with what (default is ``True``,
            which means a new :class:`threading.Lock` instance).
        region: slice-like, optional
            Where in ``out`` to store the array's results
            (default is ``None``, meaning all of ``out``).
        return_stored: bool, optional
            Whether each task returns ``out``
            (default is ``False``, meaning ``None`` is returned).
        load_stored: bool, optional
            Whether each task also loads the stored region back from ``out``.
            Ignored if ``return_stored`` is not ``True``.
            (default is ``False``, meaning defer to ``return_stored``).

        Returns
        -------
        dict
            Dask graph of the store operation, mapping ``(name, *indices)`` to
            a ``store_chunk`` or ``load_store_chunk`` task.

        Examples
        --------
        >>> import dask.array as da
        >>> d = da.ones((5, 6), chunks=(2, 3))
        >>> a = np.empty(d.shape)
        >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
        """
        if lock is True:
            lock = Lock()

        targets = slices_from_chunks(chunks)
        if region:
            # Shift every chunk slice into the requested region of ``out``.
            targets = [fuse_slice(region, target) for target in targets]

        if return_stored and load_stored:
            task_func, extra = load_store_chunk, (load_stored,)
        else:
            task_func, extra = store_chunk, ()  # type: ignore

        return {
            (name, *key[1:]): (task_func, key, out, target, lock, return_stored, *extra)
            for key, target in zip(core.flatten(keys), targets)
        }
    
    
    def retrieve_from_ooc(
        keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
    ) -> dict[tuple, Any]:
        """
        Creates a Dask graph for loading stored ``keys`` from ``dsk``.

        Parameters
        ----------
        keys: Collection
            A sequence containing Dask graph keys to load
        dsk_pre: Mapping
            A Dask graph corresponding to a Dask Array before computation.
            Each value is a store task as built by ``insert_to_ooc``:
            ``(func, source_key, out, index, lock, return_stored, *extra)``.
        dsk_post: Mapping
            A Dask graph corresponding to a Dask Array after computation

        Returns
        -------
        dict
            Maps ``("load-" + key[0], *key[1:])`` to a ``load_chunk`` task that
            reads ``index`` from the stored result under ``lock``.

        Examples
        --------
        >>> import dask.array as da
        >>> d = da.ones((5, 6), chunks=(2, 3))
        >>> a = np.empty(d.shape)
        >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
        >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
        """
        # Positions 3 and 4 of a store task hold ``index`` and ``lock``. Slice
        # them explicitly: ``load_store_chunk`` tasks carry an extra trailing
        # ``load_stored`` argument, and ``[3:-1]`` would then also forward
        # ``return_stored`` to ``load_chunk``, which only takes
        # (out, index, lock).
        load_dsk = {
            ("load-" + k[0],) + k[1:]: (load_chunk, dsk_post[k]) + dsk_pre[k][3:5]  # type: ignore
            for k in keys
        }

        return load_dsk
    
    
    def _as_dtype(a, dtype):
        if dtype is None:
            return a
        else:
            return a.astype(dtype)
    
    
    def asarray(
        a, allow_unknown_chunksizes=False, dtype=None, order=None, *, like=None, **kwargs
    ):
        """Convert the input to a dask array.

        Parameters
        ----------
        a : array-like
            Input data, in any form that can be converted to a dask array. This
            includes lists, lists of tuples, tuples, tuples of tuples, tuples of
            lists and ndarrays.
        allow_unknown_chunksizes: bool
            Allow unknown chunksizes, such as come from converting from dask
            dataframes.  Dask.array is unable to verify that chunks line up.  If
            data comes from differently aligned sources then this can cause
            unexpected results.
        dtype : data-type, optional
            By default, the data-type is inferred from the input data. When
            given, the result is cast to this data-type, including for inputs
            that are already arrays.
        order : {‘C’, ‘F’, ‘A’, ‘K’}, optional
            Memory layout. ‘A’ and ‘K’ depend on the order of input array a.
            ‘C’ row-major (C-style), ‘F’ column-major (Fortran-style) memory
            representation. ‘A’ (any) means ‘F’ if a is Fortran contiguous, ‘C’
            otherwise ‘K’ (keep) preserve input order. Defaults to ‘C’.
        like: array-like
            Reference object to allow the creation of Dask arrays with chunks
            that are not NumPy arrays. If an array-like passed in as ``like``
            supports the ``__array_function__`` protocol, the chunk type of the
            resulting array will be defined by it. In this case, it ensures the
            creation of a Dask array compatible with that passed in via this
            argument. If ``like`` is a Dask array, the chunk type of the
            resulting array will be defined by the chunk type of ``like``.
            Requires NumPy 1.20.0 or higher.

        Returns
        -------
        out : dask array
            Dask array interpretation of a.

        Examples
        --------
        >>> import dask.array as da
        >>> import numpy as np
        >>> x = np.arange(3)
        >>> da.asarray(x)
        dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

        >>> y = [[1, 2, 3], [4, 5, 6]]
        >>> da.asarray(y)
        dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

        .. warning::
            `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
            or is a list or tuple of `Array`'s. Without ``like``, it is also
            ignored when `a` already has a ``shape`` (e.g. a NumPy array).
        """
        if like is None:
            if isinstance(a, Array):
                return _as_dtype(a, dtype)
            elif hasattr(a, "to_dask_array"):
                return _as_dtype(a.to_dask_array(), dtype)
            elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
                return _as_dtype(asarray(a.data, order=order), dtype)
            elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
                return _as_dtype(
                    stack(a, allow_unknown_chunksizes=allow_unknown_chunksizes), dtype
                )
            elif not isinstance(getattr(a, "shape", None), Iterable):
                a = np.asarray(a, dtype=dtype, order=order)
        else:
            like_meta = meta_from_array(like)
            if isinstance(a, Array):
                return a.map_blocks(np.asarray, like=like_meta, dtype=dtype, order=order)
            else:
                a = np.asarray(a, like=like_meta, dtype=dtype, order=order)
        # Inputs that already have a ``shape`` (e.g. NumPy arrays) skip the
        # ``np.asarray(..., dtype=dtype)`` conversion above, so the requested
        # dtype is applied here. ``_as_dtype`` returns the array unchanged
        # when ``dtype`` is None.
        return _as_dtype(from_array(a, getitem=getter_inline, **kwargs), dtype)
    
    
    def asanyarray(a, dtype=None, order=None, *, like=None, inline_array=False):
        """Convert the input to a dask array.

        Subclasses of ``np.ndarray`` will be passed through as chunks unchanged.

        Parameters
        ----------
        a : array-like
            Input data, in any form that can be converted to a dask array. This
            includes lists, lists of tuples, tuples, tuples of tuples, tuples of
            lists and ndarrays.
        dtype : data-type, optional
            By default, the data-type is inferred from the input data. When
            given, the result is cast to this data-type, including for inputs
            that are already arrays.
        order : {‘C’, ‘F’, ‘A’, ‘K’}, optional
            Memory layout. ‘A’ and ‘K’ depend on the order of input array a.
            ‘C’ row-major (C-style), ‘F’ column-major (Fortran-style) memory
            representation. ‘A’ (any) means ‘F’ if a is Fortran contiguous, ‘C’
            otherwise ‘K’ (keep) preserve input order. Defaults to ‘C’.
        like: array-like
            Reference object to allow the creation of Dask arrays with chunks
            that are not NumPy arrays. If an array-like passed in as ``like``
            supports the ``__array_function__`` protocol, the chunk type of the
            resulting array will be defined by it. In this case, it ensures the
            creation of a Dask array compatible with that passed in via this
            argument. If ``like`` is a Dask array, the chunk type of the
            resulting array will be defined by the chunk type of ``like``.
            Requires NumPy 1.20.0 or higher.
        inline_array:
            Whether to inline the array in the resulting dask graph. For more information,
            see the documentation for ``dask.array.from_array()``.

        Returns
        -------
        out : dask array
            Dask array interpretation of a.

        Examples
        --------
        >>> import dask.array as da
        >>> import numpy as np
        >>> x = np.arange(3)
        >>> da.asanyarray(x)
        dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

        >>> y = [[1, 2, 3], [4, 5, 6]]
        >>> da.asanyarray(y)
        dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

        .. warning::
            `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
            or is a list or tuple of `Array`'s. Without ``like``, it is also
            ignored when `a` already has a ``shape`` (e.g. a NumPy array).
        """
        if like is None:
            if isinstance(a, Array):
                return _as_dtype(a, dtype)
            elif hasattr(a, "to_dask_array"):
                return _as_dtype(a.to_dask_array(), dtype)
            elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
                return _as_dtype(asarray(a.data, order=order), dtype)
            elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
                return _as_dtype(stack(a), dtype)
            elif not isinstance(getattr(a, "shape", None), Iterable):
                a = np.asanyarray(a, dtype=dtype, order=order)
        else:
            like_meta = meta_from_array(like)
            if isinstance(a, Array):
                return a.map_blocks(np.asanyarray, like=like_meta, dtype=dtype, order=order)
            else:
                a = np.asanyarray(a, like=like_meta, dtype=dtype, order=order)
        result = from_array(
            a,
            chunks=a.shape,
            getitem=getter_inline,
            asarray=False,
            inline_array=inline_array,
        )
        # Inputs that already have a ``shape`` skip the ``np.asanyarray(...,
        # dtype=dtype)`` conversion above, so the requested dtype is applied
        # here. ``_as_dtype`` returns the array unchanged when ``dtype`` is None.
        return _as_dtype(result, dtype)
    
    
    def is_scalar_for_elemwise(arg):
        """Decide whether ``elemwise`` should treat ``arg`` as a scalar.

        ``arg`` counts as a scalar when it is a NumPy/Python scalar, a NumPy
        dtype, a 0-d ``np.ndarray``, or anything whose ``shape`` attribute is
        missing or not iterable. Objects whose ``shape`` contains a dask
        collection (dask series / frames) are treated as scalars as well.

        >>> is_scalar_for_elemwise(42)
        True
        >>> is_scalar_for_elemwise('foo')
        True
        >>> is_scalar_for_elemwise(True)
        True
        >>> is_scalar_for_elemwise(np.array(42))
        True
        >>> is_scalar_for_elemwise([1, 2, 3])
        True
        >>> is_scalar_for_elemwise(np.array([1, 2, 3]))
        False
        >>> is_scalar_for_elemwise(from_array(np.array(0), chunks=()))
        False
        >>> is_scalar_for_elemwise(np.dtype('i4'))
        True
        """
        shape = getattr(arg, "shape", None)
        if isinstance(shape, Iterable):
            # Lazy entries in the shape mark dask series / frame objects.
            shapeless = any(is_dask_collection(dim) for dim in shape)
        else:
            shapeless = True

        if np.isscalar(arg) or shapeless or isinstance(arg, np.dtype):
            return True
        return isinstance(arg, np.ndarray) and arg.ndim == 0
    
    
    def broadcast_shapes(*shapes):
        """
        Compute the shape that results from broadcasting arrays together.

        Dimensions are aligned from the right. A size of 0 wins over other
        sizes, and any NaN (unknown) size makes the output size NaN. A single
        input shape is returned as-is.

        Parameters
        ----------
        shapes : tuples
            The shapes of the arguments.

        Returns
        -------
        output_shape : tuple

        Raises
        ------
        ValueError
            If the input shapes cannot be successfully broadcast together.
        """
        if len(shapes) == 1:
            return shapes[0]

        result = []
        # Walk dimensions right-to-left; shorter shapes are padded with -1.
        for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=-1):
            if np.isnan(sizes).any():
                target = np.nan
            elif 0 in sizes:
                target = 0
            else:
                target = np.max(sizes).item()

            compatible = (-1, 0, 1, target)
            for size in sizes:
                if size not in compatible and not np.isnan(size):
                    raise ValueError(
                        "operands could not be broadcast together with "
                        "shapes {}".format(" ".join(map(str, shapes)))
                    )
            result.append(target)
        return tuple(result[::-1])
    
    
    def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
        """Apply an elementwise ufunc-like function blockwise across arguments.

        Like numpy ufuncs, broadcasting rules are respected.

        Parameters
        ----------
        op : callable
            The function to apply. Should be numpy ufunc-like in the parameters
            that it accepts.
        *args : Any
            Arguments to pass to `op`. Non-dask array-like objects are first
            converted to dask arrays, then all arrays are broadcast together before
            applying the function blockwise across all arguments. Any scalar
            arguments are passed as-is following normal numpy ufunc behavior.
        out : dask array, optional
            If out is a dask.array then this overwrites the contents of that array
            with the result.
        where : array_like, optional
            An optional boolean mask marking locations where the ufunc should be
            applied. Can be a scalar, dask array, or any other array-like object.
            Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add``
            for more information.
        dtype : dtype, optional
            If provided, overrides the output array dtype.
        name : str, optional
            A unique key name to use when building the backing dask graph. If not
            provided, one will be automatically generated based on the input
            arguments.

        Returns
        -------
        Array or NotImplemented
            The result array (or ``out``, re-pointed at the result). When
            ``dtype`` is not given and calling ``op`` on small placeholder
            inputs raises, ``NotImplemented`` is returned instead.

        Raises
        ------
        TypeError
            If any unexpected keyword arguments are passed.
        ValueError
            If the shapes of the arguments cannot be broadcast together.
        NotImplementedError
            If ``out`` is neither ``None``, a 1-tuple, nor a Dask Array.

        Examples
        --------
        >>> elemwise(add, x, y)  # doctest: +SKIP
        >>> elemwise(sin, x)  # doctest: +SKIP
        >>> elemwise(sin, x, out=dask_array)  # doctest: +SKIP

        See Also
        --------
        blockwise
        """
        if kwargs:
            raise TypeError(
                f"{op.__name__} does not take the following keyword arguments "
                f"{sorted(kwargs)}"
            )

        # Unwrap/validate ``out``; turn ``where`` into True, False or an array.
        out = _elemwise_normalize_out(out)
        where = _elemwise_normalize_where(where)
        # Lists and tuples become NumPy arrays so they expose shape/ndim/dtype.
        args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args]

        # Collect every shape that takes part in broadcasting, including the
        # ``where`` mask and ``out`` when they are dask arrays.
        shapes = []
        for arg in args:
            shape = getattr(arg, "shape", ())
            if any(is_dask_collection(x) for x in shape):
                # Want to exclude Delayed shapes and dd.Scalar
                shape = ()
            shapes.append(shape)
        if isinstance(where, Array):
            shapes.append(where.shape)
        if isinstance(out, Array):
            shapes.append(out.shape)

        shapes = [s if isinstance(s, Iterable) else () for s in shapes]
        out_ndim = len(
            broadcast_shapes(*shapes)
        )  # Raises ValueError if dimensions mismatch
        # Output axes are labelled right-to-left (last axis -> 0). Inputs are
        # labelled the same way below, so they line up on trailing axes.
        expr_inds = tuple(range(out_ndim))[::-1]

        if dtype is not None:
            need_enforce_dtype = True
        else:
            # We follow NumPy's rules for dtype promotion, which special cases
            # scalars and 0d ndarrays (which it considers equivalent) by using
            # their values to compute the result dtype:
            # https://github.com/numpy/numpy/issues/6240
            # We don't inspect the values of 0d dask arrays, because these could
            # hold potentially very expensive calculations. Instead, we treat
            # them just like other arrays, and if necessary cast the result of op
            # to match.
            vals = [
                (
                    np.empty((1,) * max(1, a.ndim), dtype=a.dtype)
                    if not is_scalar_for_elemwise(a)
                    else a
                )
                for a in args
            ]
            try:
                dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False)
            except Exception:
                # Dtype inference failed: report "not handled" to the caller
                # instead of raising.
                return NotImplemented
            # Only 0-d (non-scalar) array inputs require casting each block.
            need_enforce_dtype = any(
                not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args
            )

        if not name:
            name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}"

        blockwise_kwargs = dict(dtype=dtype, name=name, token=funcname(op).strip("_"))

        if where is not True:
            # Wrap ``op`` so the mask and ``out`` travel as the two trailing
            # positional arguments and are forwarded as ``where=``/``out=``.
            blockwise_kwargs["elemwise_where_function"] = op
            op = _elemwise_handle_where
            args.extend([where, out])

        if need_enforce_dtype:
            # Wrap the (possibly already wrapped) ``op`` so every computed
            # block is cast to ``dtype``.
            blockwise_kwargs["enforce_dtype"] = dtype
            blockwise_kwargs["enforce_dtype_function"] = op
            op = _enforce_dtype

        result = blockwise(
            op,
            expr_inds,
            *concat(
                # Pair each argument with its right-to-left axis labels; scalar
                # arguments get ``None`` so they are passed as-is.
                (a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None)
                for a in args
            ),
            **blockwise_kwargs,
        )

        # Overwrite ``out`` with the result if given; otherwise return the result.
        return handle_out(out, result)
    
    
    def _elemwise_normalize_where(where):
        if where is True:
            return True
        elif where is False or where is None:
            return False
        return asarray(where)
    
    
    def _elemwise_handle_where(*args, **kwargs):
        function = kwargs.pop("elemwise_where_function")
        *args, where, out = args
        if hasattr(out, "copy"):
            out = out.copy()
        return function(*args, where=where, out=out, **kwargs)
    
    
    def _elemwise_normalize_out(out):
        if isinstance(out, tuple):
            if len(out) == 1:
                out = out[0]
            elif len(out) > 1:
                raise NotImplementedError("The out parameter is not fully supported")
            else:
                out = None
        if not (out is None or isinstance(out, Array)):
            raise NotImplementedError(
                f"The out parameter is not fully supported."
                f" Received type {type(out).__name__}, expected Dask Array"
            )
        return out
    
    
    def handle_out(out, result):
        """Apply an ``out=`` argument to a freshly computed ``result``.

        When ``out`` normalizes to a Dask Array, its chunks, graph, meta and
        name are replaced by those of ``result`` (the shapes must match), and
        ``out`` is returned. Otherwise ``result`` is returned unchanged.
        """
        target = _elemwise_normalize_out(out)
        if not isinstance(target, Array):
            return result
        if target.shape != result.shape:
            raise ValueError(
                "Mismatched shapes between result and out parameter. "
                "out=%s, result=%s" % (str(target.shape), str(result.shape))
            )
        target._chunks = result.chunks
        target.dask = result.dask
        target._meta = result._meta
        target._name = result.name
        return target
    
    
    def _enforce_dtype(*args, **kwargs):
        """Calls a function and converts its result to the given dtype.
    
        The parameters have deliberately been given unwieldy names to avoid
        clashes with keyword arguments consumed by blockwise
    
        A dtype of `object` is treated as a special case and not enforced,
        because it is used as a dummy value in some places when the result will
        not be a block in an Array.
    
        Parameters
        ----------
        enforce_dtype : dtype
            Result dtype
        enforce_dtype_function : callable
            The wrapped function, which will be passed the remaining arguments
        """
        dtype = kwargs.pop("enforce_dtype")
        function = kwargs.pop("enforce_dtype_function")
    
        result = function(*args, **kwargs)
        if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
            if not np.can_cast(result, dtype, casting="same_kind"):
                raise ValueError(
                    "Inferred dtype from function %r was %r "
                    "but got %r, which can't be cast using "
                    "casting='same_kind'"
                    % (funcname(function), str(dtype), str(result.dtype))
                )
            if np.isscalar(result):
                # scalar astype method doesn't take the keyword arguments, so
                # have to convert via 0-dimensional array and back.
                result = result.astype(dtype)
            else:
                try:
                    result = result.astype(dtype, copy=False)
                except TypeError:
                    # Missing copy kwarg
                    result = result.astype(dtype)
        return result
    
    
    def broadcast_to(x, shape, chunks=None, meta=None):
        """Broadcast an array to a new shape.

        Parameters
        ----------
        x : array_like
            The array to broadcast.
        shape : tuple
            The shape of the desired array.
        chunks : tuple, optional
            If provided, then the result will use these chunks instead of the same
            chunks as the source array. Setting chunks explicitly as part of
            broadcast_to is more efficient than rechunking afterwards. Chunks are
            only allowed to differ from the original shape along dimensions that
            are new on the result or have size 1 in the input array.
        meta : empty ndarray
            empty ndarray created with same NumPy backend, ndim and dtype as the
            Dask Array being created (overrides dtype)

        Returns
        -------
        broadcast : dask array

        Raises
        ------
        ValueError
            If ``x.shape`` cannot be broadcast to ``shape``, or if ``chunks``
            differ from ``x``'s chunks along an existing dimension whose chunks
            are not ``(1,)``.

        See Also
        --------
        :func:`numpy.broadcast_to`
        """
        x = asarray(x)
        shape = tuple(shape)

        if meta is None:
            meta = meta_from_array(x)

        # Nothing to do when the shape and (if given) the chunks already match.
        if x.shape == shape and (chunks is None or chunks == x.chunks):
            return x

        # Number of new leading axes. Each existing axis must match the target
        # size unless it has size 1.
        ndim_new = len(shape) - x.ndim
        if ndim_new < 0 or any(
            new != old for new, old in zip(shape[ndim_new:], x.shape) if old != 1
        ):
            raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

        if chunks is None:
            # Default chunking: a single chunk along each new leading axis,
            # x's chunks where its size is > 1, and a single chunk covering
            # the target size where x's size is 0 or 1.
            chunks = tuple((s,) for s in shape[:ndim_new]) + tuple(
                bd if old > 1 else (new,)
                for bd, old, new in zip(x.chunks, x.shape, shape[ndim_new:])
            )
        else:
            chunks = normalize_chunks(
                chunks, shape, dtype=x.dtype, previous_chunks=x.chunks
            )
            # User chunks may only differ from x's along axes chunked as (1,).
            for old_bd, new_bd in zip(x.chunks, chunks[ndim_new:]):
                if old_bd != new_bd and old_bd != (1,):
                    raise ValueError(
                        "cannot broadcast chunks %s to chunks %s: "
                        "new chunks must either be along a new "
                        "dimension or a dimension of size 1" % (x.chunks, chunks)
                    )

        name = "broadcast_to-" + tokenize(x, shape, chunks)
        dsk = {}

        # ``ec`` holds one (block index, block size) pair per output axis;
        # ``zip(*ec)`` splits it into the output block index and block shape.
        enumerated_chunks = product(*(enumerate(bds) for bds in chunks))
        for new_index, chunk_shape in (zip(*ec) for ec in enumerated_chunks):
            # Axes chunked as (1,) in x always read source block 0; the other
            # axes reuse the output block index, skipping the new leading axes.
            old_index = tuple(
                0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:])
            )
            old_key = (x.name,) + old_index
            new_key = (name,) + new_index
            # NOTE(review): ``quote`` presumably keeps the shape tuple from
            # being interpreted as a task — confirm against dask.core.
            dsk[new_key] = (np.broadcast_to, old_key, quote(chunk_shape))

        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
        return Array(graph, name, chunks, dtype=x.dtype, meta=meta)
    
    
    @derived_from(np)
    def broadcast_arrays(*args, subok=False):
        # Docstring is supplied by ``derived_from(np)``; comments only here.
        convert = asanyarray if bool(subok) else asarray
        arrays = tuple(convert(arg) for arg in args)

        # Give all arrays matching chunks along their right-aligned axes.
        index_pairs = concat((arr, list(range(arr.ndim))[::-1]) for arr in arrays)
        _, arrays = unify_chunks(*index_pairs, warn=False)

        shape = broadcast_shapes(*(arr.shape for arr in arrays))
        chunks = broadcast_chunks(*(arr.chunks for arr in arrays))

        broadcasted = (broadcast_to(arr, shape=shape, chunks=chunks) for arr in arrays)
        # NumPy >= 2 returns a tuple from broadcast_arrays; older NumPy a list.
        return tuple(broadcasted) if NUMPY_GE_200 else list(broadcasted)
    
    
    def offset_func(func, offset, *args):
        """Wrap ``func`` so each positional argument is shifted by ``offset``.

        The i-th argument has ``offset[i]`` added before ``func`` is called.
        Arguments and offsets are paired up to the shorter of the two, so any
        extras are dropped. When possible the wrapper is renamed to
        ``"offset_" + func.__name__``.

        >>> double = lambda x: x * 2
        >>> f = offset_func(double, (10,))
        >>> f(1)
        22
        >>> f(300)
        620
        """

        def _offset(*args):
            return func(*[arg + delta for arg, delta in zip(args, offset)])

        with contextlib.suppress(Exception):
            _offset.__name__ = "offset_" + func.__name__

        return _offset
    
    
    def chunks_from_arrays(arrays):
        """Chunks tuple from nested list of arrays
    
        >>> x = np.array([1, 2])
        >>> chunks_from_arrays([x, x])
        ((2, 2),)
    
        >>> x = np.array([[1, 2]])
        >>> chunks_from_arrays([[x], [x]])
        ((1, 1), (2,))
    
        >>> x = np.array([[1, 2]])
        >>> chunks_from_arrays([[x, x]])
        ((1,), (2, 2))
    
        >>> chunks_from_arrays([1, 1])
        ((1, 1),)
        """
        if not arrays:
            return ()
        result = []
        dim = 0
    
        def shape(x):
            try:
                return x.shape if x.shape else (1,)
            except AttributeError:
                return (1,)
    
        while isinstance(arrays, (list, tuple)):
>           result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
E           IndexError: tuple index out of range

../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5237: IndexError

Check warning on line 0 in distributed.tests.test_client

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

8 out of 9 runs failed: test_release_persisted_collection (distributed.tests.test_client)

artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 0s]
Raw output
IndexError: tuple index out of range
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38475', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:42523', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:39277', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>

    @gen_cluster(client=True)
    async def test_release_persisted_collection(c, s, a, b):
        np = pytest.importorskip("numpy")
        da = pytest.importorskip("dask.array")
    
        arr = c.persist(da.random.random((10,), chunks=(10,)))
    
        await wait(arr)
    
        _release_persisted(arr)
        while s.tasks:
            await asyncio.sleep(0.01)
    
        with pytest.raises(CancelledError):
>           await c.compute(arr)

distributed/tests/test_client.py:8235: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:1284: in finalize
    return concatenate3(results)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5452: in concatenate3
    chunks = chunks_from_arrays(arrays)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5237: in chunks_from_arrays
    result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import contextlib
    import math
    import operator
    import os
    import pickle
    import re
    import sys
    import traceback
    import uuid
    import warnings
    from bisect import bisect
    from collections import defaultdict
    from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
    from functools import lru_cache, partial, reduce, wraps
    from itertools import product, zip_longest
    from numbers import Integral, Number
    from operator import add, mul
    from threading import Lock
    from typing import Any, Literal, TypeVar, Union, cast
    
    import numpy as np
    from numpy.typing import ArrayLike
    from packaging.version import Version
    from tlz import accumulate, concat, first, groupby, partition
    from tlz.curried import pluck
    from toolz import frequencies
    
    from dask import compute, config, core
    from dask.array import chunk
    from dask.array.chunk import getitem
    from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
    
    # Keep einsum_lookup and tensordot_lookup here for backwards compatibility
    from dask.array.dispatch import (  # noqa: F401
        concatenate_lookup,
        einsum_lookup,
        tensordot_lookup,
    )
    from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
    from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
    from dask.array.utils import compute_meta, meta_from_array
    from dask.base import (
        DaskMethodsMixin,
        compute_as_if_collection,
        dont_optimize,
        is_dask_collection,
        named_schedulers,
        persist,
        tokenize,
    )
    from dask.blockwise import blockwise as core_blockwise
    from dask.blockwise import broadcast_dimensions
    from dask.context import globalmethod
    from dask.core import quote
    from dask.delayed import Delayed, delayed
    from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
    from dask.layers import ArrayBlockIdDep, ArraySliceDep, ArrayValuesDep, reshapelist
    from dask.sizeof import sizeof
    from dask.typing import Graph, Key, NestedKeys
    from dask.utils import (
        IndexCallable,
        SerializableLock,
        cached_cumsum,
        cached_property,
        concrete,
        derived_from,
        format_bytes,
        funcname,
        has_keyword,
        is_arraylike,
        is_dataframe_like,
        is_index_like,
        is_integer,
        is_series_like,
        maybe_pluralize,
        ndeepmap,
        ndimlist,
        parse_bytes,
        typename,
    )
    from dask.widgets import get_template
    
    # Type alias for a size that is either an int or NaN (NaN is a float, hence
    # the ``float`` member of the union).
    T_IntOrNaN = Union[int, float]  # Should be Union[int, Literal[np.nan]]
    
    # Default scheduler: the "threads" entry of ``named_schedulers`` when present,
    # otherwise the "sync" entry (the fallback is looked up eagerly, so "sync"
    # must always be registered).
    DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
    
    # Help text pointing users at ways to compute unknown chunk sizes. It begins
    # with blank lines, presumably so it can be appended to an error message.
    unknown_chunk_message = (
        "\n\n"
        "A possible solution: "
        "https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n"
        "Summary: to compute chunks sizes, use\n\n"
        "   x.compute_chunk_sizes()  # for Dask Array `x`\n"
        "   ddf.to_dask_array(lengths=True)  # for Dask DataFrame `ddf`"
    )
    
    
    class PerformanceWarning(Warning):
        """Warning category for chunking choices that are likely to perform poorly."""
    
    
    def getter(a, b, asarray=True, lock=None):
        """Return ``a[b]``, optionally under ``lock`` and coerced to a NumPy array.

        ``None`` entries in a tuple index (new axes) are handled in two steps.
        First ``a`` is indexed without them. Then the new axes are inserted by
        indexing that result.

        When ``asarray`` is true, a result that is not array-like, or that is an
        ``np.matrix``, is converted with ``np.asarray``. If ``lock`` is given, it
        is held while ``a`` is indexed and the result is converted.
        """
        if isinstance(b, tuple) and any(idx is None for idx in b):
            without_new_axes = tuple(idx for idx in b if idx is not None)
            # Integer indices remove their axis, so they get no entry in the
            # follow-up index. Every other entry either keeps its axis
            # (``slice(None)``) or becomes a new axis (``None``).
            expand = tuple(
                slice(None, None) if idx is not None else None
                for idx in b
                if not isinstance(idx, Integral)
            )
            return getter(a, without_new_axes, asarray=asarray, lock=lock)[expand]

        if lock:
            lock.acquire()
        try:
            result = a[b]
            # ``np.matrix`` counts as array-like, so it is special-cased to keep
            # the historical behaviour of always returning a plain ``np.ndarray``.
            if asarray:
                if not is_arraylike(result) or isinstance(result, np.matrix):
                    result = np.asarray(result)
        finally:
            if lock:
                lock.release()
        return result
    
    
    def getter_nofancy(a, b, asarray=True, lock=None):
        """Forward directly to ``getter``; only the function identity differs.

        Using this function instead of ``getter`` signals to the optimization
        passes that the backend does not support fancy indexing.
        """
        return getter(a, b, lock=lock, asarray=asarray)
    
    
    def getter_inline(a, b, asarray=True, lock=None):
        """Variant of ``getter`` that optimizations may inline.

        It behaves exactly like ``getter``; only the function identity differs.
        Slicing operations that use it may be inlined into a graph, as in the
        following rewrite

        **Before**

        >>> a = x[:10]  # doctest: +SKIP
        >>> b = a + 1  # doctest: +SKIP
        >>> c = a * 2  # doctest: +SKIP

        **After**

        >>> b = x[:10] + 1  # doctest: +SKIP
        >>> c = x[:10] * 2  # doctest: +SKIP

        Inlining like this can matter when the data is read from disk.
        """
        return getter(a, b, lock=lock, asarray=asarray)
    
    
    from dask.array.optimization import fuse_slice, optimize
    
    # __array_function__ dict for mapping aliases and mismatching names.
    # Maps a NumPy function to the dask.array function that implements it;
    # entries are added by the ``implements`` decorator.
    _HANDLED_FUNCTIONS = {}
    
    
    def implements(*numpy_functions):
        """Decorator that registers a dask.array function for ``__array_function__``.

        The decorated function becomes the handler for every NumPy function passed
        here (several can be given to cover aliases). The decorated function itself
        is returned unchanged.

        Parameters
        ----------
        \\*numpy_functions : callables
            One or more NumPy functions that are handled by ``__array_function__``
            and will be mapped by `implements` to a `dask.array` function.
        """

        def register(dask_func):
            # A later registration for the same NumPy function overwrites an earlier one.
            _HANDLED_FUNCTIONS.update(dict.fromkeys(numpy_functions, dask_func))
            return dask_func

        return register
    
    
    def _should_delegate(self, other) -> bool:
        """Check whether Dask should delegate to the other.
        This implementation follows NEP-13:
        https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
        """
        if hasattr(other, "__array_ufunc__") and other.__array_ufunc__ is None:
            return True
        elif (
            hasattr(other, "__array_ufunc__")
            and not is_valid_array_chunk(other)
            # don't delegate to our own parent classes
            and not isinstance(self, type(other))
            and type(self) is not type(other)
        ):
            return True
        return False
    
    
    def check_if_handled_given_other(f):
        """Decorate a binary dunder method so that it defers to ``other`` when needed.

        If ``_should_delegate(self, other)`` is true, the wrapped method returns
        ``NotImplemented``, which lets Python try the reflected method on
        ``other``. Otherwise ``f(self, other)`` is called.

        This ensures proper deferral to upcast types without assuming that
        unknown types are automatically downcast types.
        """

        @wraps(f)
        def wrapper(self, other):
            if not _should_delegate(self, other):
                return f(self, other)
            return NotImplemented

        return wrapper
    
    
    def slices_from_chunks(chunks):
        """Translate a chunks tuple into per-block slice tuples.

        Blocks are listed in ``itertools.product`` order, so the last axis
        varies fastest.

        >>> slices_from_chunks(((2, 2), (3, 3, 3)))  # doctest: +NORMALIZE_WHITESPACE
         [(slice(0, 2, None), slice(0, 3, None)),
          (slice(0, 2, None), slice(3, 6, None)),
          (slice(0, 2, None), slice(6, 9, None)),
          (slice(2, 4, None), slice(0, 3, None)),
          (slice(2, 4, None), slice(3, 6, None)),
          (slice(2, 4, None), slice(6, 9, None))]
        """
        # Start offsets of every chunk along each axis (a leading 0 is included).
        offsets_per_axis = [cached_cumsum(sizes, initial_zero=True) for sizes in chunks]
        axis_slices = []
        for offsets, sizes in zip(offsets_per_axis, chunks):
            axis_slices.append(
                [slice(start, start + size) for start, size in zip(offsets, sizes)]
            )
        return list(product(*axis_slices))
    
    
    def graph_from_arraylike(
        arr,  # Any array-like which supports slicing
        chunks,
        shape,
        name,
        getitem=getter,
        lock=False,
        asarray=True,
        dtype=None,
        inline_array=False,
    ) -> HighLevelGraph:
        """
        HighLevelGraph for slicing chunks from an array-like according to a chunk pattern.

        If ``inline_array`` is True, this makes a Blockwise layer of slicing tasks
        where the array-like is embedded into every task.

        If ``inline_array`` is False, this inserts the array-like as a standalone value in
        a MaterializedLayer, then generates a Blockwise layer of slicing tasks that refer
        to it.

        ``asarray`` and ``lock`` are passed on to ``getitem`` only when it accepts
        both keywords and either ``asarray`` is falsy or ``lock`` is truthy.
        Otherwise the tasks call ``getitem`` without them.

        >>> dict(graph_from_arraylike(arr, chunks=(2, 3), shape=(4, 6), name="X", inline_array=True))  # doctest: +SKIP
        {('X', 0, 0): (getter, arr, (slice(0, 2), slice(0, 3))),
         ('X', 1, 0): (getter, arr, (slice(2, 4), slice(0, 3))),
         ('X', 1, 1): (getter, arr, (slice(2, 4), slice(3, 6))),
         ('X', 0, 1): (getter, arr, (slice(0, 2), slice(3, 6)))}

        >>> dict(  # doctest: +SKIP
        ...     graph_from_arraylike(arr, chunks=((2, 2), (3, 3)), shape=(4,6), name="X", inline_array=False)
        ... )
        {"original-X": arr,
         ('X', 0, 0): (getter, 'original-X', (slice(0, 2), slice(0, 3))),
         ('X', 1, 0): (getter, 'original-X', (slice(2, 4), slice(0, 3))),
         ('X', 1, 1): (getter, 'original-X', (slice(2, 4), slice(3, 6))),
         ('X', 0, 1): (getter, 'original-X', (slice(0, 2), slice(3, 6)))}
        """
        chunks = normalize_chunks(chunks, shape, dtype=dtype)
        out_ind = tuple(range(len(shape)))

        if (
            has_keyword(getitem, "asarray")
            and has_keyword(getitem, "lock")
            and (not asarray or lock)
        ):
            kwargs = {"asarray": asarray, "lock": lock}
        else:
            # Common case, drop extra parameters
            kwargs = {}

        # When inlining, the array-like itself is the blockwise source argument.
        # Otherwise the tasks refer to it by the key of a separate layer.
        original_name = "original-" + name
        source = arr if inline_array else original_name
        slicing_layer = core_blockwise(
            getitem,
            name,
            out_ind,
            source,
            None,
            ArraySliceDep(chunks),
            out_ind,
            numblocks={},
            **kwargs,
        )

        if inline_array:
            return HighLevelGraph.from_collections(name, slicing_layer)

        layers = {
            original_name: MaterializedLayer({original_name: arr}),
            name: slicing_layer,
        }
        deps = {
            original_name: set(),
            name: {original_name},
        }
        return HighLevelGraph(layers, deps)
    
    
    def dotmany(A, B, leftfunc=None, rightfunc=None, **kwargs):
        """Sum of the dot products of aligned pairs of chunks.

        Computes ``np.dot(a, b, **kwargs)`` for each pair ``(a, b)`` taken from
        ``A`` and ``B`` in lockstep (stopping at the shorter one) and adds the
        results. Empty inputs give ``0``.

        >>> x = np.array([[1, 2], [1, 2]])
        >>> y = np.array([[10, 20], [10, 20]])
        >>> dotmany([x, x, x], [y, y, y])
        array([[ 90, 180],
               [ 90, 180]])

        Optionally pass in functions to apply to the left and right chunks first

        >>> dotmany([x, x, x], [y, y, y], rightfunc=np.transpose)
        array([[150, 150],
               [150, 150]])
        """
        lefts = map(leftfunc, A) if leftfunc else A
        rights = map(rightfunc, B) if rightfunc else B
        return sum(np.dot(left, right, **kwargs) for left, right in zip(lefts, rights))
    
    
    def _concatenate2(arrays, axes=None):
        """Recursively concatenate nested lists of arrays along axes
    
        Each entry in axes corresponds to each level of the nested list.  The
        length of axes should correspond to the level of nesting of arrays.
        If axes is an empty list or tuple, return arrays, or arrays[0] if
        arrays is a list.
    
        >>> x = np.array([[1, 2], [3, 4]])
        >>> _concatenate2([x, x], axes=[0])
        array([[1, 2],
               [3, 4],
               [1, 2],
               [3, 4]])
    
        >>> _concatenate2([x, x], axes=[1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4],
               [1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        Supports Iterators
        >>> _concatenate2(iter([x, x]), axes=[1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        Special Case
        >>> _concatenate2([x, x], axes=())
        array([[1, 2],
               [3, 4]])
        """
        if axes is None:
            axes = []
    
        if axes == ():
            if isinstance(arrays, list):
                return arrays[0]
            else:
                return arrays
    
        if isinstance(arrays, Iterator):
            arrays = list(arrays)
        if not isinstance(arrays, (list, tuple)):
            return arrays
        if len(axes) > 1:
            arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
        concatenate = concatenate_lookup.dispatch(
            type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0)))
        )
        if isinstance(arrays[0], dict):
            # Handle concatenation of `dict`s, used as a replacement for structured
            # arrays when that's not supported by the array library (e.g., CuPy).
            keys = list(arrays[0].keys())
            assert all(list(a.keys()) == keys for a in arrays)
            ret = dict()
            for k in keys:
                ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0])
            return ret
        else:
            return concatenate(arrays, axis=axes[0])
    
    
    def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
        """
        Infer the output dtype(s) of ``func`` by calling it on tiny stand-in inputs.

        Parameters
        ----------
        func: Callable
            Function whose output dtype is to be determined.

        args: List of array like
            Arguments that would normally be passed to ``func``. Each array-like
            argument is replaced by a ones-filled array of shape ``(1,) * ndim``
            with the same ``dtype``, built from ``meta_from_array`` of the
            argument. Other arguments are passed through unchanged.

        kwargs: dict
            Additional keyword arguments for ``func``.

        funcname: String
            Name of the calling function, used in the error message.

        suggest_dtype: None/False or String
            If truthy, the error message suggests passing a dtype through the
            keyword argument of this name. Defaults to ``'dtype'``.

        nout: None or Int
            ``None`` if ``func`` returns a single output, an integer if it returns
            several. Defaults to ``None``.

        Returns
        -------
        : dtype or tuple of dtype
            With ``nout=None``: the result's ``dtype`` attribute, or the result's
            Python type if it has none. Otherwise a tuple holding the ``dtype`` of
            each output.

        Raises
        ------
        ValueError
            If calling ``func`` raises. The message includes the original error
            and its traceback.
        """
        from dask.array.utils import meta_from_array

        probe_args = []
        for arg in args:
            if is_arraylike(arg):
                arg = np.ones_like(
                    meta_from_array(arg), shape=((1,) * arg.ndim), dtype=arg.dtype
                )
            probe_args.append(arg)

        failure = None
        try:
            # The dummy data may trigger numerical warnings; those are irrelevant here.
            with np.errstate(all="ignore"):
                out = func(*probe_args, **kwargs)
        except Exception as err:
            tb = "".join(traceback.format_tb(sys.exc_info()[2]))
            if suggest_dtype:
                suggest = (
                    "Please specify the dtype explicitly using the "
                    "`{dtype}` kwarg.\n\n".format(dtype=suggest_dtype)
                )
            else:
                suggest = ""
            failure = (
                f"`dtype` inference failed in `{funcname}`.\n\n"
                f"{suggest}"
                "Original error is below:\n"
                "------------------------\n"
                f"{err!r}\n\n"
                "Traceback:\n"
                "---------\n"
                f"{tb}"
            )
        # Raised outside the ``except`` block, so the original exception is not
        # attached as context; its details are already in the message.
        if failure is not None:
            raise ValueError(failure)

        if nout is None:
            return getattr(out, "dtype", type(out))
        return tuple(result.dtype for result in out)
    
    
    def normalize_arg(x):
        """Normalize a user-provided argument to ``blockwise`` or ``map_blocks``.

        Dask collections are returned untouched. Any other value is wrapped with
        ``dask.delayed``, so it lives in the graph on its own, when it is:

        1.  a string that starts with an underscore followed by digits, since
            such strings might collide with ``blockwise_token`` placeholders,
        2.  a list with ten or more elements, or
        3.  large, i.e. ``sizeof(x) > 1e6``.

        All other values are returned unchanged.
        """
        if is_dask_collection(x):
            return x
        needs_wrapping = (
            (isinstance(x, str) and re.match(r"_\d+", x))
            or (isinstance(x, list) and len(x) >= 10)
            or sizeof(x) > 1e6
        )
        return delayed(x) if needs_wrapping else x
    
    
    def _pass_extra_kwargs(func, keys, *args, **kwargs):
        """Helper for :func:`dask.array.map_blocks` to pass `block_info` or `block_id`.
    
        For each element of `keys`, a corresponding element of args is changed
        to a keyword argument with that key, before all arguments re passed on
        to `func`.
        """
        kwargs.update(zip(keys, args))
        return func(*args[len(keys) :], **kwargs)
    
    
    def map_blocks(
        func,
        *args,
        name=None,
        token=None,
        dtype=None,
        chunks=None,
        drop_axis=None,
        new_axis=None,
        enforce_ndim=False,
        meta=None,
        **kwargs,
    ):
        """Map a function across all blocks of a dask array.
    
        Note that ``map_blocks`` will attempt to automatically determine the output
        array type by calling ``func`` on 0-d versions of the inputs. Please refer to
        the ``meta`` keyword argument below if you expect that the function will not
        succeed when operating on 0-d arrays.
    
        Parameters
        ----------
        func : callable
            Function to apply to every block in the array.
            If ``func`` accepts ``block_info=`` or ``block_id=``
            as keyword arguments, these will be passed dictionaries
            containing information about input and output chunks/arrays
            during computation. See examples for details.
        args : dask arrays or other objects
        dtype : np.dtype, optional
            The ``dtype`` of the output array. It is recommended to provide this.
            If not provided, will be inferred by applying the function to a small
            set of fake data.
        chunks : tuple, optional
            Chunk shape of resulting blocks if the function does not preserve
            shape. If not provided, the resulting array is assumed to have the same
            block structure as the first input array.
        drop_axis : number or iterable, optional
            Dimensions lost by the function.
        new_axis : number or iterable, optional
            New dimensions created by the function. Note that these are applied
            after ``drop_axis`` (if present). The size of each chunk along this
            dimension will be set to 1. Please specify ``chunks`` if the individual
            chunks have a different size.
        enforce_ndim : bool, default False
            Whether to enforce at runtime that the dimensionality of the array
            produced by ``func`` actually matches that of the array returned by
            ``map_blocks``.
            If True, this will raise an error when there is a mismatch.
        token : string, optional
            The key prefix to use for the output array. If not provided, will be
            determined from the function name.
        name : string, optional
            The key name to use for the output array. Note that this fully
            specifies the output key name, and must be unique. If not provided,
            will be determined by a hash of the arguments.
        meta : array-like, optional
            The ``meta`` of the output array, when specified is expected to be an
            array of the same type and dtype of that returned when calling ``.compute()``
            on the array returned by this function. When not provided, ``meta`` will be
            inferred by applying the function to a small set of fake data, usually a
            0-d array. It's important to ensure that ``func`` can successfully complete
            computation without raising exceptions when 0-d is passed to it, providing
            ``meta`` will be required otherwise. If the output type is known beforehand
            (e.g., ``np.ndarray``, ``cupy.ndarray``), an empty array of such type dtype
            can be passed, for example: ``meta=np.array((), dtype=np.int32)``.
        **kwargs :
            Other keyword arguments to pass to function. Values must be constants
            (not dask.arrays)
    
        See Also
        --------
        dask.array.map_overlap : Generalized operation with overlap between neighbors.
        dask.array.blockwise : Generalized operation with control over block alignment.
    
        Examples
        --------
        >>> import dask.array as da
        >>> x = da.arange(6, chunks=3)
    
        >>> x.map_blocks(lambda x: x * 2).compute()
        array([ 0,  2,  4,  6,  8, 10])
    
        The ``da.map_blocks`` function can also accept multiple arrays.
    
        >>> d = da.arange(5, chunks=2)
        >>> e = da.arange(5, chunks=2)
    
        >>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
        >>> f.compute()
        array([ 0,  2,  6, 12, 20])
    
        If the function changes shape of the blocks then you must provide chunks
        explicitly.
    
        >>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))
    
        You have a bit of freedom in specifying chunks.  If all of the output chunk
        sizes are the same, you can provide just that chunk size as a single tuple.
    
        >>> a = da.arange(18, chunks=(6,))
        >>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))
    
        If the function changes the dimension of the blocks you must specify the
        created or destroyed dimensions.
    
        >>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
        ...                  new_axis=[0, 2])
    
        If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
        add the necessary number of axes on the left.
    
        Note that ``map_blocks()`` will concatenate chunks along axes specified by
        the keyword parameter ``drop_axis`` prior to applying the function.
        This is illustrated in the figure below:
    
        .. image:: /images/map_blocks_drop_axis.png
    
        Due to memory-size-constraints, it is often not advisable to use ``drop_axis``
        on an axis that is chunked.  In that case, it is better not to use
        ``map_blocks`` but rather
        ``dask.array.reduction(..., axis=dropped_axes, concatenate=False)`` which
        maintains a leaner memory footprint while it drops any axis.
    
        Map_blocks aligns blocks by block positions without regard to shape. In the
        following example we have two arrays with the same number of blocks but
        with different shape and chunk sizes.
    
        >>> x = da.arange(1000, chunks=(100,))
        >>> y = da.arange(100, chunks=(10,))
    
        The relevant attribute to match is numblocks.
    
        >>> x.numblocks
        (10,)
        >>> y.numblocks
        (10,)
    
        If these match (up to broadcasting rules) then we can map arbitrary
        functions across blocks
    
        >>> def func(a, b):
        ...     return np.array([a.max(), b.max()])
    
        >>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
        dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
    
        >>> _.compute()
        array([ 99,   9, 199,  19, 299,  29, 399,  39, 499,  49, 599,  59, 699,
                69, 799,  79, 899,  89, 999,  99])
    
        Your block function can get information about where it is in the array by
        accepting a special ``block_info`` or ``block_id`` keyword argument.
        During computation, they will contain information about each of the input
        and output chunks (and dask arrays) relevant to each call of ``func``.
    
        >>> def func(block_info=None):
        ...     pass
    
        This will receive the following information:
    
        >>> block_info  # doctest: +SKIP
        {0: {'shape': (1000,),
             'num-chunks': (10,),
             'chunk-location': (4,),
             'array-location': [(400, 500)]},
         None: {'shape': (1000,),
                'num-chunks': (10,),
                'chunk-location': (4,),
                'array-location': [(400, 500)],
                'chunk-shape': (100,),
                'dtype': dtype('float64')}}
    
        The keys to the ``block_info`` dictionary indicate which is the input and
        output Dask array:
    
        - **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
          The dictionary key is ``0`` because that is the argument index corresponding
          to the first input Dask array.
          In cases where multiple Dask arrays have been passed as input to the function,
          you can access them with the number corresponding to the input argument,
          eg: ``block_info[1]``, ``block_info[2]``, etc.
          (Note that if you pass multiple Dask arrays as input to map_blocks,
          the arrays must match each other by having matching numbers of chunks,
          along corresponding dimensions up to broadcasting rules.)
        - **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
          and contains information about the output chunks.
          The output chunk shape and dtype may may be different than the input chunks.
    
        For each dask array, ``block_info`` describes:
    
        - ``shape``: the shape of the full Dask array,
        - ``num-chunks``: the number of chunks of the full array in each dimension,
        - ``chunk-location``: the chunk location (for example the fourth chunk over
          in the first dimension), and
        - ``array-location``: the array location within the full Dask array
          (for example the slice corresponding to ``40:50``).
    
        In addition to these, there are two extra parameters described by
        ``block_info`` for the output array (in ``block_info[None]``):
    
        - ``chunk-shape``: the output chunk shape, and
        - ``dtype``: the output dtype.
    
        These features can be combined to synthesize an array from scratch, for
        example:
    
        >>> def func(block_info=None):
        ...     loc = block_info[None]['array-location'][0]
        ...     return np.arange(loc[0], loc[1])
    
        >>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
        dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>
    
        >>> _.compute()
        array([0, 1, 2, 3, 4, 5, 6, 7])
    
        ``block_id`` is similar to ``block_info`` but contains only the ``chunk_location``:
    
        >>> def func(block_id=None):
        ...     pass
    
        This will receive the following information:
    
        >>> block_id  # doctest: +SKIP
        (4, 3)
    
        You may specify the key name prefix of the resulting task in the graph with
        the optional ``token`` keyword argument.
    
        >>> x.map_blocks(lambda x: x + 1, name='increment')
        dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>
    
        For functions that may not handle 0-d arrays, it's also possible to specify
        ``meta`` with an empty array matching the type of the expected result. In
        the example below, ``func`` will result in an ``IndexError`` when computing
        ``meta``:
    
        >>> rng = da.random.default_rng()
        >>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
        dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
    
        Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
        a ``dtype``:
    
        >>> import cupy  # doctest: +SKIP
        >>> rng = da.random.default_rng(cupy.random.default_rng())  # doctest: +SKIP
        >>> dt = np.float32
        >>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt))  # doctest: +SKIP
        dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
        """
        if drop_axis is None:
            drop_axis = []
    
        if not callable(func):
            msg = (
                "First argument must be callable function, not %s\n"
                "Usage:   da.map_blocks(function, x)\n"
                "   or:   da.map_blocks(function, x, y, z)"
            )
            raise TypeError(msg % type(func).__name__)
        if token:
            warnings.warn(
                "The `token=` keyword to `map_blocks` has been moved to `name=`. "
                "Please use `name=` instead as the `token=` keyword will be removed "
                "in a future release.",
                category=FutureWarning,
            )
            name = token
    
        name = f"{name or funcname(func)}-{tokenize(func, dtype, chunks, drop_axis, new_axis, *args, **kwargs)}"
        new_axes = {}
    
        if isinstance(drop_axis, Number):
            drop_axis = [drop_axis]
        if isinstance(new_axis, Number):
            new_axis = [new_axis]  # TODO: handle new_axis
    
        arrs = [a for a in args if isinsta…tack
        """
        from dask.array import wrap
    
        seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]
    
        if not seq:
            raise ValueError("Need array(s) to concatenate")
    
        if axis is None:
            seq = [a.flatten() for a in seq]
            axis = 0
    
        seq_metas = [meta_from_array(s) for s in seq]
        _concatenate = concatenate_lookup.dispatch(
            type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0)))
        )
        meta = _concatenate(seq_metas, axis=axis)
    
        # Promote types to match meta
        seq = [a.astype(meta.dtype) for a in seq]
    
        # Find output array shape
        ndim = len(seq[0].shape)
        shape = tuple(
            sum(a.shape[i] for a in seq) if i == axis else seq[0].shape[i]
            for i in range(ndim)
        )
    
        # Drop empty arrays
        seq2 = [a for a in seq if a.size]
        if not seq2:
            seq2 = seq
    
        if axis < 0:
            axis = ndim + axis
        if axis >= ndim:
            msg = (
                "Axis must be less than than number of dimensions"
                "\nData has %d dimensions, but got axis=%d"
            )
            raise ValueError(msg % (ndim, axis))
    
        n = len(seq2)
        if n == 0:
            try:
                return wrap.empty_like(meta, shape=shape, chunks=shape, dtype=meta.dtype)
            except TypeError:
                return wrap.empty(shape, chunks=shape, dtype=meta.dtype)
        elif n == 1:
            return seq2[0]
    
        if not allow_unknown_chunksizes and not all(
            i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2)
            for i in range(ndim)
        ):
            if any(map(np.isnan, seq2[0].shape)):
                raise ValueError(
                    "Tried to concatenate arrays with unknown"
                    " shape %s.\n\nTwo solutions:\n"
                    "  1. Force concatenation pass"
                    " allow_unknown_chunksizes=True.\n"
                    "  2. Compute shapes with "
                    "[x.compute_chunk_sizes() for x in seq]" % str(seq2[0].shape)
                )
            raise ValueError("Shapes do not align: %s", [x.shape for x in seq2])
    
        inds = [list(range(ndim)) for i in range(n)]
        for i, ind in enumerate(inds):
            ind[axis] = -(i + 1)
    
        uc_args = list(concat(zip(seq2, inds)))
        _, seq2 = unify_chunks(*uc_args, warn=False)
    
        bds = [a.chunks for a in seq2]
    
        chunks = (
            seq2[0].chunks[:axis]
            + (sum((bd[axis] for bd in bds), ()),)
            + seq2[0].chunks[axis + 1 :]
        )
    
        cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in seq2]))
    
        names = [a.name for a in seq2]
    
        name = "concatenate-" + tokenize(names, axis)
        keys = list(product([name], *[range(len(bd)) for bd in chunks]))
    
        values = [
            (names[bisect(cum_dims, key[axis + 1]) - 1],)
            + key[1 : axis + 1]
            + (key[axis + 1] - cum_dims[bisect(cum_dims, key[axis + 1]) - 1],)
            + key[axis + 2 :]
            for key in keys
        ]
    
        dsk = dict(zip(keys, values))
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq2)
    
        return Array(graph, name, chunks, meta=meta)
    
    
    def load_store_chunk(
        x: Any,
        out: Any,
        index: slice,
        lock: Any,
        return_stored: bool,
        load_stored: bool,
    ):
        """
        A function inserted in a Dask graph for storing a chunk.

        Parameters
        ----------
        x: array-like or None
            An array (potentially a NumPy one). Nothing is written to ``out``
            when ``x`` is ``None`` or has zero size.
        out: array-like
            Where to store results.
        index: slice-like
            Where to store result from ``x`` in ``out``.
        lock: Lock-like or False
            Lock to use before writing to ``out``. When truthy it is acquired
            before the write and released afterwards (also on error); any
            read-back from ``out`` happens while it is still held.
        return_stored: bool
            Whether to return ``out``.
        load_stored: bool
            Whether to return the array stored in ``out``.
            Ignored if ``return_stored`` is not ``True``.

        Returns
        -------

        If return_stored=True and load_stored=False
            out
        If return_stored=True and load_stored=True
            out[index]
        If return_stored=False
            None

        Examples
        --------

        >>> a = np.ones((5, 6))
        >>> b = np.empty(a.shape)
        >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
        """
        if lock:
            lock.acquire()
        try:
            # ``None`` means "load only"; empty chunks have nothing to write.
            if x is not None and x.size != 0:
                # Chunks that are not array-like are coerced before assignment.
                out[index] = x if is_arraylike(x) else np.asanyarray(x)

            if not return_stored:
                return None
            return out[index] if load_stored else out
        finally:
            if lock:
                lock.release()
    
    
    def store_chunk(
        x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
    ):
        """Write ``x`` into ``out[index]`` without reading stored data back.

        Delegates to ``load_store_chunk`` with ``load_stored`` fixed to ``False``,
        so the result is ``out`` when ``return_stored`` is true and ``None``
        otherwise.
        """
        load_stored = False
        return load_store_chunk(x, out, index, lock, return_stored, load_stored)
    
    
    A = TypeVar("A", bound=ArrayLike)


    def load_chunk(out: A, index: slice, lock: Any) -> A:
        """Read ``out[index]`` (holding ``lock`` if truthy) without writing anything."""
        return load_store_chunk(
            None, out, index, lock, return_stored=True, load_stored=True
        )
    
    
    def insert_to_ooc(
        keys: list,
        chunks: tuple[tuple[int, ...], ...],
        out: ArrayLike,
        name: str,
        *,
        lock: Lock | bool = True,
        region: tuple[slice, ...] | slice | None = None,
        return_stored: bool = False,
        load_stored: bool = False,
    ) -> dict:
        """
        Creates a Dask graph for storing chunks of the input array in ``out``.

        One task is produced per input key. The task is keyed by
        ``(name,) + key[1:]`` and has the form
        ``(func, key, out, slice, lock, return_stored)``. When both
        ``return_stored`` and ``load_stored`` are true, ``func`` is
        ``load_store_chunk`` and ``load_stored`` is appended. Otherwise
        ``func`` is ``store_chunk``.

        Parameters
        ----------
        keys: list
            Dask keys of the input array
        chunks: tuple
            Dask chunks of the input array
        out: array-like
            Where to store results to
        name: str
            First element of dask keys
        lock: Lock-like or bool, optional
            Whether to lock or with what (default is ``True``,
            which means a :class:`threading.Lock` instance).
        region: slice-like, optional
            Where in ``out`` to store the input array's results
            (default is ``None``, meaning all of ``out``).
        return_stored: bool, optional
            Whether to return ``out``
            (default is ``False``, meaning ``None`` is returned).
        load_stored: bool, optional
            Whether to handle loading from ``out`` at the same time.
            Ignored if ``return_stored`` is not ``True``.
            (default is ``False``, meaning defer to ``return_stored``).

        Returns
        -------
        dask graph of store operation

        Examples
        --------
        >>> import dask.array as da
        >>> d = da.ones((5, 6), chunks=(2, 3))
        >>> a = np.empty(d.shape)
        >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
        """

        # A single lock instance is shared by every store task in this graph.
        if lock is True:
            lock = Lock()

        slices = slices_from_chunks(chunks)
        if region:
            # Shift each chunk's slice so it lands inside ``region`` of ``out``.
            slices = [fuse_slice(region, slc) for slc in slices]

        if return_stored and load_stored:
            func = load_store_chunk
            args = (load_stored,)
        else:
            func = store_chunk  # type: ignore
            args = ()  # type: ignore

        dsk = {
            (name,) + t[1:]: (func, t, out, slc, lock, return_stored) + args
            for t, slc in zip(core.flatten(keys), slices)
        }
        return dsk
    
    
    def retrieve_from_ooc(
        keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
    ) -> dict[tuple, Any]:
        """
        Build a Dask graph that loads previously stored ``keys`` back.

        Each output key is the input key with ``"load-"`` prefixed to its first
        element. Its task calls ``load_chunk`` on ``dsk_post[key]``, followed by
        the elements ``[3:-1]`` of the store task ``dsk_pre[key]``.

        Parameters
        ----------
        keys: Collection
            A sequence containing Dask graph keys to load
        dsk_pre: Mapping
            A Dask graph corresponding to a Dask Array before computation
        dsk_post: Mapping
            A Dask graph corresponding to a Dask Array after computation

        Examples
        --------
        >>> import dask.array as da
        >>> d = da.ones((5, 6), chunks=(2, 3))
        >>> a = np.empty(d.shape)
        >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
        >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
        """
        load_dsk = {}
        for key in keys:
            load_key = ("load-" + key[0],) + key[1:]
            stored = dsk_post[key]
            # Presumably dsk_pre[key] is a store task built like
            # (func, source_key, out, region, lock, return_stored); the [3:-1]
            # slice then forwards (region, lock) to load_chunk. Verify callers.
            extra_args = dsk_pre[key][3:-1]
            load_dsk[load_key] = (load_chunk, stored) + extra_args  # type: ignore
        return load_dsk
    
    
    def _as_dtype(a, dtype):
        if dtype is None:
            return a
        else:
            return a.astype(dtype)
    
    
    def asarray(
        a, allow_unknown_chunksizes=False, dtype=None, order=None, *, like=None, **kwargs
    ):
        """Convert the input to a dask array.

        Parameters
        ----------
        a : array-like
            Input data, in any form that can be converted to a dask array. This
            includes lists, lists of tuples, tuples, tuples of tuples, tuples of
            lists and ndarrays.
        allow_unknown_chunksizes: bool
            Allow unknown chunksizes, such as come from converting from dask
            dataframes.  Dask.array is unable to verify that chunks line up.  If
            data comes from differently aligned sources then this can cause
            unexpected results.
        dtype : data-type, optional
            By default, the data-type is inferred from the input data. If given,
            the result is cast to this data-type, including for inputs (such as
            NumPy arrays) that already carry a dtype of their own.
        order : {‘C’, ‘F’, ‘A’, ‘K’}, optional
            Memory layout. ‘A’ and ‘K’ depend on the order of input array a.
            ‘C’ row-major (C-style), ‘F’ column-major (Fortran-style) memory
            representation. ‘A’ (any) means ‘F’ if a is Fortran contiguous, ‘C’
            otherwise ‘K’ (keep) preserve input order. Defaults to ‘C’.
        like: array-like
            Reference object to allow the creation of Dask arrays with chunks
            that are not NumPy arrays. If an array-like passed in as ``like``
            supports the ``__array_function__`` protocol, the chunk type of the
            resulting array will be defined by it. In this case, it ensures the
            creation of a Dask array compatible with that passed in via this
            argument. If ``like`` is a Dask array, the chunk type of the
            resulting array will be defined by the chunk type of ``like``.
            Requires NumPy 1.20.0 or higher.

        Returns
        -------
        out : dask array
            Dask array interpretation of a.

        Examples
        --------
        >>> import dask.array as da
        >>> import numpy as np
        >>> x = np.arange(3)
        >>> da.asarray(x)
        dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

        >>> y = [[1, 2, 3], [4, 5, 6]]
        >>> da.asarray(y)
        dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

        .. warning::
            `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
            or is a list or tuple of `Array`'s.
        """
        if like is None:
            if isinstance(a, Array):
                return _as_dtype(a, dtype)
            elif hasattr(a, "to_dask_array"):
                return _as_dtype(a.to_dask_array(), dtype)
            elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
                return _as_dtype(asarray(a.data, order=order), dtype)
            elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
                return _as_dtype(
                    stack(a, allow_unknown_chunksizes=allow_unknown_chunksizes), dtype
                )
            elif not isinstance(getattr(a, "shape", None), Iterable):
                a = np.asarray(a, dtype=dtype, order=order)
        else:
            like_meta = meta_from_array(like)
            if isinstance(a, Array):
                # ``dtype=`` given to ``map_blocks`` only declares the output
                # dtype; it is not forwarded to ``np.asarray``. Cast the input
                # first so the chunks really have the declared dtype.
                return _as_dtype(a, dtype).map_blocks(
                    np.asarray, like=like_meta, dtype=dtype, order=order
                )
            else:
                a = np.asarray(a, like=like_meta, dtype=dtype, order=order)
        # Inputs that already expose a ``shape`` (e.g. NumPy arrays) bypass the
        # ``np.asarray(..., dtype=dtype)`` conversion above, so ``dtype`` has to
        # be applied here. For already-converted inputs the dtypes should match.
        return _as_dtype(from_array(a, getitem=getter_inline, **kwargs), dtype)
    
    
    def asanyarray(a, dtype=None, order=None, *, like=None, inline_array=False):
        """Convert the input to a dask array.

        Subclasses of ``np.ndarray`` will be passed through as chunks unchanged.

        Parameters
        ----------
        a : array-like
            Input data, in any form that can be converted to a dask array. This
            includes lists, lists of tuples, tuples, tuples of tuples, tuples of
            lists and ndarrays.
        dtype : data-type, optional
            By default, the data-type is inferred from the input data. If given,
            the result is cast to this data-type, including for inputs (such as
            NumPy arrays) that already carry a dtype of their own.
        order : {‘C’, ‘F’, ‘A’, ‘K’}, optional
            Memory layout. ‘A’ and ‘K’ depend on the order of input array a.
            ‘C’ row-major (C-style), ‘F’ column-major (Fortran-style) memory
            representation. ‘A’ (any) means ‘F’ if a is Fortran contiguous, ‘C’
            otherwise ‘K’ (keep) preserve input order. Defaults to ‘C’.
        like: array-like
            Reference object to allow the creation of Dask arrays with chunks
            that are not NumPy arrays. If an array-like passed in as ``like``
            supports the ``__array_function__`` protocol, the chunk type of the
            resulting array will be defined by it. In this case, it ensures the
            creation of a Dask array compatible with that passed in via this
            argument. If ``like`` is a Dask array, the chunk type of the
            resulting array will be defined by the chunk type of ``like``.
            Requires NumPy 1.20.0 or higher.
        inline_array:
            Whether to inline the array in the resulting dask graph. For more information,
            see the documentation for ``dask.array.from_array()``.

        Returns
        -------
        out : dask array
            Dask array interpretation of a.

        Examples
        --------
        >>> import dask.array as da
        >>> import numpy as np
        >>> x = np.arange(3)
        >>> da.asanyarray(x)
        dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

        >>> y = [[1, 2, 3], [4, 5, 6]]
        >>> da.asanyarray(y)
        dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

        .. warning::
            `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
            or is a list or tuple of `Array`'s.
        """
        if like is None:
            if isinstance(a, Array):
                return _as_dtype(a, dtype)
            elif hasattr(a, "to_dask_array"):
                return _as_dtype(a.to_dask_array(), dtype)
            elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
                return _as_dtype(asarray(a.data, order=order), dtype)
            elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
                return _as_dtype(stack(a), dtype)
            elif not isinstance(getattr(a, "shape", None), Iterable):
                a = np.asanyarray(a, dtype=dtype, order=order)
        else:
            like_meta = meta_from_array(like)
            if isinstance(a, Array):
                # ``dtype=`` given to ``map_blocks`` only declares the output
                # dtype; it is not forwarded to ``np.asanyarray``. Cast the input
                # first so the chunks really have the declared dtype.
                return _as_dtype(a, dtype).map_blocks(
                    np.asanyarray, like=like_meta, dtype=dtype, order=order
                )
            else:
                a = np.asanyarray(a, like=like_meta, dtype=dtype, order=order)
        # Inputs that already expose a ``shape`` (e.g. NumPy arrays) bypass the
        # ``np.asanyarray(..., dtype=dtype)`` conversion above, so ``dtype`` has
        # to be applied here. For already-converted inputs the dtypes should match.
        result = from_array(
            a,
            chunks=a.shape,
            getitem=getter_inline,
            asarray=False,
            inline_array=inline_array,
        )
        return _as_dtype(result, dtype)
    
    
    def is_scalar_for_elemwise(arg):
        """Return whether ``elemwise`` should pass ``arg`` through as a scalar.

        NumPy scalars, Python scalars, ``np.dtype`` objects, 0-d ndarrays,
        anything without an iterable ``shape``, and objects whose ``shape``
        contains dask collections all count as scalars.

        >>> is_scalar_for_elemwise(42)
        True
        >>> is_scalar_for_elemwise('foo')
        True
        >>> is_scalar_for_elemwise(True)
        True
        >>> is_scalar_for_elemwise(np.array(42))
        True
        >>> is_scalar_for_elemwise([1, 2, 3])
        True
        >>> is_scalar_for_elemwise(np.array([1, 2, 3]))
        False
        >>> is_scalar_for_elemwise(from_array(np.array(0), chunks=()))
        False
        >>> is_scalar_for_elemwise(np.dtype('i4'))
        True
        """
        shape = getattr(arg, "shape", None)
        if isinstance(shape, Iterable):
            # A shape holding dask collections is how dask series / frames show
            # up here; those are treated as scalars in elemwise.
            shape_says_scalar = any(is_dask_collection(dim) for dim in shape)
        else:
            shape_says_scalar = True

        if np.isscalar(arg) or shape_says_scalar:
            return True
        if isinstance(arg, np.dtype):
            return True
        return isinstance(arg, np.ndarray) and arg.ndim == 0
    
    
    def broadcast_shapes(*shapes):
        """
        Compute the shape that results from broadcasting arrays of ``shapes``.

        Dimensions are aligned from the right. A NaN in any aligned position
        makes the output dimension NaN (unknown). A 0 anywhere makes it 0.
        Otherwise the output takes the largest size. Sizes of 1 and missing
        leading dimensions always broadcast.

        Parameters
        ----------
        shapes : tuples
            The shapes of the arguments.

        Returns
        -------
        output_shape : tuple

        Raises
        ------
        ValueError
            If the input shapes cannot be successfully broadcast together.
        """
        if len(shapes) == 1:
            return shapes[0]

        right_aligned = [reversed(s) for s in shapes]
        merged = []
        # -1 marks "this shape has no dimension here" for shorter shapes.
        for column in zip_longest(*right_aligned, fillvalue=-1):
            if np.isnan(column).any():
                size = np.nan
            elif 0 in column:
                size = 0
            else:
                size = np.max(column).item()

            compatible = (-1, 0, 1, size)
            for entry in column:
                if entry not in compatible and not np.isnan(entry):
                    raise ValueError(
                        "operands could not be broadcast together with "
                        "shapes {}".format(" ".join(map(str, shapes)))
                    )
            merged.append(size)

        merged.reverse()
        return tuple(merged)
    
    
    def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
        """Apply an elementwise ufunc-like function blockwise across arguments.

        Like numpy ufuncs, broadcasting rules are respected.

        Parameters
        ----------
        op : callable
            The function to apply. Should be numpy ufunc-like in the parameters
            that it accepts.
        *args : Any
            Arguments to pass to `op`. Non-dask array-like objects are first
            converted to dask arrays, then all arrays are broadcast together before
            applying the function blockwise across all arguments. Any scalar
            arguments are passed as-is following normal numpy ufunc behavior.
        out : dask array, optional
            If out is a dask.array then this overwrites the contents of that array
            with the result.
        where : array_like, optional
            An optional boolean mask marking locations where the ufunc should be
            applied. Can be a scalar, dask array, or any other array-like object.
            Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add``
            for more information.
        dtype : dtype, optional
            If provided, overrides the output array dtype.
        name : str, optional
            A unique key name to use when building the backing dask graph. If not
            provided, one will be automatically generated based on the input
            arguments.

        Returns
        -------
        The result of ``handle_out(out, result)``: ``out`` rebound to the new
        graph when ``out`` is a dask array, otherwise the new dask array.
        Returns ``NotImplemented`` when ``dtype`` is not given and inferring it
        by calling ``op`` on placeholder inputs raises.

        Raises
        ------
        TypeError
            If any extra keyword arguments are passed.
        ValueError
            If the argument shapes cannot be broadcast together.

        Examples
        --------
        >>> elemwise(add, x, y)  # doctest: +SKIP
        >>> elemwise(sin, x)  # doctest: +SKIP
        >>> elemwise(sin, x, out=dask_array)  # doctest: +SKIP

        See Also
        --------
        blockwise
        """
        if kwargs:
            raise TypeError(
                f"{op.__name__} does not take the following keyword arguments "
                f"{sorted(kwargs)}"
            )

        out = _elemwise_normalize_out(out)
        where = _elemwise_normalize_where(where)
        # Lists/tuples become NumPy arrays so they take part in shape/dtype logic.
        args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args]

        shapes = []
        for arg in args:
            shape = getattr(arg, "shape", ())
            # NOTE(review): this iterates ``shape`` before the non-Iterable
            # normalization a few lines below, so an arg whose ``shape`` is not
            # iterable would raise TypeError here. Confirm that cannot happen.
            if any(is_dask_collection(x) for x in shape):
                # Want to exclude Delayed shapes and dd.Scalar
                shape = ()
            shapes.append(shape)
        if isinstance(where, Array):
            shapes.append(where.shape)
        if isinstance(out, Array):
            shapes.append(out.shape)

        # Anything whose shape is not iterable is treated as 0-d.
        shapes = [s if isinstance(s, Iterable) else () for s in shapes]
        out_ndim = len(
            broadcast_shapes(*shapes)
        )  # Raises ValueError if dimensions mismatch
        # Indices run from out_ndim-1 down to 0. Each array arg below is also
        # indexed in reverse, so inputs align on their trailing dimensions.
        expr_inds = tuple(range(out_ndim))[::-1]

        if dtype is not None:
            need_enforce_dtype = True
        else:
            # We follow NumPy's rules for dtype promotion, which special cases
            # scalars and 0d ndarrays (which it considers equivalent) by using
            # their values to compute the result dtype:
            # https://github.com/numpy/numpy/issues/6240
            # We don't inspect the values of 0d dask arrays, because these could
            # hold potentially very expensive calculations. Instead, we treat
            # them just like other arrays, and if necessary cast the result of op
            # to match.
            # Non-scalar args are replaced by tiny uninitialized arrays with the
            # same dtype and ndim (at least 1-d) for dtype inference.
            vals = [
                (
                    np.empty((1,) * max(1, a.ndim), dtype=a.dtype)
                    if not is_scalar_for_elemwise(a)
                    else a
                )
                for a in args
            ]
            try:
                dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False)
            except Exception:
                return NotImplemented
            need_enforce_dtype = any(
                not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args
            )

        if not name:
            name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}"

        blockwise_kwargs = dict(dtype=dtype, name=name, token=funcname(op).strip("_"))

        if where is not True:
            # Wrap op so that ``where`` and ``out`` arrive as the last two
            # positional args, which _elemwise_handle_where unpacks.
            blockwise_kwargs["elemwise_where_function"] = op
            op = _elemwise_handle_where
            args.extend([where, out])

        if need_enforce_dtype:
            # Wrap the (possibly where-wrapped) op so every block's result is
            # cast to ``dtype``.
            blockwise_kwargs["enforce_dtype"] = dtype
            blockwise_kwargs["enforce_dtype_function"] = op
            op = _enforce_dtype

        result = blockwise(
            op,
            expr_inds,
            # Scalar-like args get index ``None`` and are passed through as-is.
            *concat(
                (a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None)
                for a in args
            ),
            **blockwise_kwargs,
        )

        return handle_out(out, result)
    
    
    def _elemwise_normalize_where(where):
        if where is True:
            return True
        elif where is False or where is None:
            return False
        return asarray(where)
    
    
    def _elemwise_handle_where(*args, **kwargs):
        function = kwargs.pop("elemwise_where_function")
        *args, where, out = args
        if hasattr(out, "copy"):
            out = out.copy()
        return function(*args, where=where, out=out, **kwargs)
    
    
    def _elemwise_normalize_out(out):
        if isinstance(out, tuple):
            if len(out) == 1:
                out = out[0]
            elif len(out) > 1:
                raise NotImplementedError("The out parameter is not fully supported")
            else:
                out = None
        if not (out is None or isinstance(out, Array)):
            raise NotImplementedError(
                f"The out parameter is not fully supported."
                f" Received type {type(out).__name__}, expected Dask Array"
            )
        return out
    
    
    def handle_out(out, result):
        """Handle out parameters

        If ``out`` (after normalization) is a dask Array, its chunks, graph,
        meta and name are replaced by those of ``result`` and ``out`` is
        returned. Otherwise ``result`` is returned unchanged. A shape mismatch
        between the two raises ``ValueError``.
        """
        out = _elemwise_normalize_out(out)
        if not isinstance(out, Array):
            return result

        if out.shape != result.shape:
            raise ValueError(
                "Mismatched shapes between result and out parameter. "
                "out=%s, result=%s" % (str(out.shape), str(result.shape))
            )
        # Rebind out's internals so the caller's object now refers to result.
        out._chunks = result.chunks
        out.dask = result.dask
        out._meta = result._meta
        out._name = result.name
        return out
    
    
    def _enforce_dtype(*args, **kwargs):
        """Calls a function and converts its result to the given dtype.
    
        The parameters have deliberately been given unwieldy names to avoid
        clashes with keyword arguments consumed by blockwise
    
        A dtype of `object` is treated as a special case and not enforced,
        because it is used as a dummy value in some places when the result will
        not be a block in an Array.
    
        Parameters
        ----------
        enforce_dtype : dtype
            Result dtype
        enforce_dtype_function : callable
            The wrapped function, which will be passed the remaining arguments
        """
        dtype = kwargs.pop("enforce_dtype")
        function = kwargs.pop("enforce_dtype_function")
    
        result = function(*args, **kwargs)
        if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
            if not np.can_cast(result, dtype, casting="same_kind"):
                raise ValueError(
                    "Inferred dtype from function %r was %r "
                    "but got %r, which can't be cast using "
                    "casting='same_kind'"
                    % (funcname(function), str(dtype), str(result.dtype))
                )
            if np.isscalar(result):
                # scalar astype method doesn't take the keyword arguments, so
                # have to convert via 0-dimensional array and back.
                result = result.astype(dtype)
            else:
                try:
                    result = result.astype(dtype, copy=False)
                except TypeError:
                    # Missing copy kwarg
                    result = result.astype(dtype)
        return result
    
    
    def broadcast_to(x, shape, chunks=None, meta=None):
        """Broadcast an array to a new shape.

        Parameters
        ----------
        x : array_like
            The array to broadcast.
        shape : tuple
            The shape of the desired array.
        chunks : tuple, optional
            If provided, then the result will use these chunks instead of the same
            chunks as the source array. Setting chunks explicitly as part of
            broadcast_to is more efficient than rechunking afterwards. Chunks are
            only allowed to differ from the original shape along dimensions that
            are new on the result or have size 1 in the input array.
        meta : empty ndarray
            empty ndarray created with same NumPy backend, ndim and dtype as the
            Dask Array being created (overrides dtype)

        Returns
        -------
        broadcast : dask array

        Raises
        ------
        ValueError
            If ``x.shape`` cannot be broadcast to ``shape``, or if ``chunks``
            differ from ``x.chunks`` along an existing dimension whose chunks
            are not ``(1,)``.

        See Also
        --------
        :func:`numpy.broadcast_to`
        """
        x = asarray(x)
        shape = tuple(shape)

        if meta is None:
            meta = meta_from_array(x)

        if x.shape == shape and (chunks is None or chunks == x.chunks):
            return x

        ndim_new = len(shape) - x.ndim
        if ndim_new < 0 or any(
            new != old for new, old in zip(shape[ndim_new:], x.shape) if old != 1
        ):
            raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

        if chunks is None:
            # New leading dimensions get a single chunk each. Existing
            # dimensions keep their chunks unless they are stretched from size 1.
            chunks = tuple((s,) for s in shape[:ndim_new]) + tuple(
                bd if old > 1 else (new,)
                for bd, old, new in zip(x.chunks, x.shape, shape[ndim_new:])
            )
        else:
            chunks = normalize_chunks(
                chunks, shape, dtype=x.dtype, previous_chunks=x.chunks
            )
            for old_bd, new_bd in zip(x.chunks, chunks[ndim_new:]):
                if old_bd != new_bd and old_bd != (1,):
                    raise ValueError(
                        "cannot broadcast chunks %s to chunks %s: "
                        "new chunks must either be along a new "
                        "dimension or a dimension of size 1" % (x.chunks, chunks)
                    )

        name = "broadcast_to-" + tokenize(x, shape, chunks)
        dsk = {}

        for block in product(*(enumerate(bds) for bds in chunks)):
            # ``block`` pairs each output block index with that block's size.
            # Build both tuples explicitly: for a 0-d result ``block`` is empty,
            # and unpacking ``zip(*block)`` into two names would raise.
            new_index = tuple(i for i, _ in block)
            chunk_shape = tuple(size for _, size in block)
            # Dimensions of x with a single size-1 chunk always read block 0.
            old_index = tuple(
                0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:])
            )
            old_key = (x.name,) + old_index
            new_key = (name,) + new_index
            dsk[new_key] = (np.broadcast_to, old_key, quote(chunk_shape))

        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
        return Array(graph, name, chunks, dtype=x.dtype, meta=meta)
    
    
    @derived_from(np)
    def broadcast_arrays(*args, subok=False):
        # No docstring on purpose: ``derived_from`` builds the documentation
        # from NumPy's and folds in this function's own ``__doc__``.
        # ``subok`` picks asanyarray (keeps ndarray subclasses) over asarray.
        converter = asanyarray if bool(subok) else asarray
        arrays = tuple(converter(arg) for arg in args)

        # Unify uneven chunking. Each array is indexed by its dimensions in
        # reverse order, so trailing dimensions share labels across inputs.
        index_specs = [list(range(arr.ndim))[::-1] for arr in arrays]
        _, arrays = unify_chunks(*concat(zip(arrays, index_specs)), warn=False)

        out_shape = broadcast_shapes(*(arr.shape for arr in arrays))
        out_chunks = broadcast_chunks(*(arr.chunks for arr in arrays))

        broadcasted = [
            broadcast_to(arr, shape=out_shape, chunks=out_chunks) for arr in arrays
        ]
        # Match NumPy's return type: a tuple on NumPy >= 2.0, a list before.
        return tuple(broadcasted) if NUMPY_GE_200 else broadcasted
    
    
    def offset_func(func, offset, *args):
        """Wrap ``func`` so each positional argument is shifted by ``offset``.

        The i-th argument has ``offset[i]`` added before ``func`` is called.
        Pairing stops at the shorter of the arguments and ``offset``. The
        wrapper is renamed ``offset_<name>`` when ``func`` has a usable
        ``__name__``.

        >>> double = lambda x: x * 2
        >>> f = offset_func(double, (10,))
        >>> f(1)
        22
        >>> f(300)
        620
        """

        def _offset(*args):
            return func(*[value + shift for value, shift in zip(args, offset)])

        try:
            _offset.__name__ = "offset_" + func.__name__
        except Exception:
            # Best effort only: callables without a usable __name__ keep "_offset".
            pass

        return _offset
    
    
    def chunks_from_arrays(arrays):
        """Chunks tuple from nested list of arrays
    
        >>> x = np.array([1, 2])
        >>> chunks_from_arrays([x, x])
        ((2, 2),)
    
        >>> x = np.array([[1, 2]])
        >>> chunks_from_arrays([[x], [x]])
        ((1, 1), (2,))
    
        >>> x = np.array([[1, 2]])
        >>> chunks_from_arrays([[x, x]])
        ((1,), (2, 2))
    
        >>> chunks_from_arrays([1, 1])
        ((1, 1),)
        """
        if not arrays:
            return ()
        result = []
        dim = 0
    
        def shape(x):
            try:
                return x.shape if x.shape else (1,)
            except AttributeError:
                return (1,)
    
        while isinstance(arrays, (list, tuple)):
>           result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
E           IndexError: tuple index out of range

../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5237: IndexError

Check warning on line 0 in distributed.tests.test_stress

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

8 out of 9 runs failed: test_stress_creation_and_deletion (distributed.tests.test_stress)

artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 8s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 7s]
Raw output
AssertionError: assert 'round-c642c375b3c4b5b78c9119e82165ed0a' == 8000884.93
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33919', workers: 0, cores: 0, tasks: 0>

    @pytest.mark.slow
    @gen_cluster(
        nthreads=[],
        client=True,
        scheduler_kwargs={"allowed_failures": 100_000},
    )
    async def test_stress_creation_and_deletion(c, s):
        # Assertions are handled by the validate mechanism in the scheduler
        pytest.importorskip("numpy")
        da = pytest.importorskip("dask.array")
    
        rng = da.random.RandomState(0)
        x = rng.random(size=(2000, 2000), chunks=(100, 100))
        y = ((x + 1).T + (x * 2) - x.mean(axis=1)).sum().round(2)
        z = c.persist(y)
    
        async def create_and_destroy_worker(delay):
            start = time()
            while time() < start + 5:
                async with Worker(s.address, nthreads=2) as n:
                    await asyncio.sleep(delay)
    
        await asyncio.gather(*(create_and_destroy_worker(0.1 * i) for i in range(20)))
    
        async with Worker(s.address, nthreads=2):
>           assert await c.compute(z) == 8000884.93
E           AssertionError: assert 'round-c642c375b3c4b5b78c9119e82165ed0a' == 8000884.93

distributed/tests/test_stress.py:125: AssertionError

Check warning on line 0 in distributed.comm.tests.test_comms

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

1 out of 7 runs failed: test_tls_comm_closed_implicit[tornado] (distributed.comm.tests.test_comms)

artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
Raw output
ssl.SSLError: [SYS] unknown error (_ssl.c:2580)
tcp = <module 'distributed.comm.tcp' from '/home/runner/work/distributed/distributed/distributed/comm/tcp.py'>

    @gen_test()
    async def test_tls_comm_closed_implicit(tcp):
>       await check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs)

distributed/comm/tests/test_comms.py:777: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/comm/tests/test_comms.py:763: in check_comm_closed_implicit
    await comm.read()
distributed/comm/tcp.py:225: in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:422: in read_bytes
    self._try_inline_read()
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:836: in _try_inline_read
    pos = self._read_to_buffer_loop()
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:750: in _read_to_buffer_loop
    if self._read_to_buffer() == 0:
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:861: in _read_to_buffer
    bytes_read = self.read_from_fd(buf)
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:1552: in read_from_fd
    return self.socket.recv_into(buf, len(buf))
../../../miniconda3/envs/dask-distributed/lib/python3.11/ssl.py:1314: in recv_into
    return self.read(nbytes, buffer)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <ssl.SSLSocket [closed] fd=-1, family=2, type=1, proto=0>, len = 65536
buffer = bytearray(b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x...0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')

    def read(self, len=1024, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF."""
    
        self._checkClosed()
        if self._sslobj is None:
            raise ValueError("Read on closed or unwrapped SSL socket.")
        try:
            if buffer is not None:
>               return self._sslobj.read(len, buffer)
E               ssl.SSLError: [SYS] unknown error (_ssl.c:2580)

../../../miniconda3/envs/dask-distributed/lib/python3.11/ssl.py:1166: SSLError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_configuration[p2p-tasks] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 39810932bc72623ea866af170a896ed5 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '39810932bc72623ea866af170a896ed5'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…wait self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:42801', workers: 0, cores: 0, tasks: 0>
config_value = 'tasks', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:39467', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:35233', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.54100301, 0.39676024, 0.80875911, 0.57210751, 0.49079247,
        0.98476684, 0.55330331, 0.02055366, 0.4693..., 0.57304034, 0.71045835, 0.80842066, 0.66522237,
        0.10320915, 0.77049627, 0.21466926, 0.95670238, 0.30806249]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 39810932bc72623ea866af170a896ed5 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_configuration[p2p-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P e07f19a3ad28b5edecde97f5c6e27492 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'e07f19a3ad28b5edecde97f5c6e27492'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Notify the scheduler that this run reached the barrier.

            The barrier is reported as consistent only if every entry of ``run_ids``
            equals this run's ``run_id``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            expected = self.run_id
            consistent = all(other == expected for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=consistent,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Issue a single ``shuffle_receive`` RPC to ``address`` (no retries)."""
            self.raise_if_closed()
            shuffle_receive = self.rpc(address).shuffle_receive
            payload = to_serialize(shards)
            return await shuffle_receive(
                data=payload, shuffle_id=self.id, run_id=self.run_id
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the RETRY_* settings.

            When the mean shard size is below 65536 bytes, the shards are pickled into a
            single bytes blob; the receiving side unpickles it in ``receive``. Larger
            shards are passed through unchanged.
            """
            # Small shards: don't send buffers individually over the tcp comms; merge
            # everything into one opaque bytes blob instead.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            return await retry(
                lambda: self._send(address, shards_or_bytes),
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                pool = self.executor
                return await run_in_executor_with_context(pool, func, *args, **kwargs)
    
        def heartbeat(self) -> dict[str, Any]:
            """Return a progress snapshot with "disk", "comm" and "start" entries.

            The comm buffer's heartbeat dict is extended with ``"read"``, the value of
            ``self.total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer, refusing if the run is closed."""
            self.raise_if_closed()
            buffer = self._comm_buffer
            await buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keying each entry by its index joined with "_"."""
            self.raise_if_closed()
            by_key = {"_".join(map(str, index)): value for index, value in data.items()}
            await self._disk_buffer.write(by_key)
    
        def raise_if_closed(self) -> None:
            """Raise if this run is closed; no-op otherwise.

            Raises the recorded failure (``self._exception``) if one is set, else
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the input phase finished and flush pending outgoing shards.

            If the comm buffer recorded an error, it is stored on ``self._exception``
            and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as comm_error:
                # Remember the failure; raise_if_closed() raises it once the run is closed.
                self._exception = comm_error
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer after checking the run is still open."""
            self.raise_if_closed()
            buffer = self._comm_buffer
            await buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the receive-side (disk or memory) buffer after checking the run is open."""
            self.raise_if_closed()
            buffer = self._disk_buffer
            await buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers of this run.

            Subsequent calls wait for the first call to finish and then return. The disk
            buffer is closed and the closed event is set even if closing the comm buffer
            raises, so concurrent ``close()`` callers never wait forever.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                # Release any concurrent close() callers waiting on the event above.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as this run's failure, unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the entries stored under index ``id`` (joined with "_") from the disk buffer."""
            self.raise_if_closed()
            buffer_key = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer's ``send``.

            ``bytes`` payloads are the pickled blobs produced by ``send`` and are
            unpickled first. A ``P2PConsistencyError`` is converted into an error
            message; any other exception propagates.
            """
            try:
                shards = data
                if isinstance(shards, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads on bytes received over the network;
                    # safe only as long as peers are trusted cluster workers.
                    shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Return if this worker is assigned output partition ``i``; otherwise reschedule.

            When another worker is assigned, the scheduler is asked to restrict task
            ``key`` to that worker and ``Reschedule`` is raised. A scheduler error reply
            becomes ``RuntimeError``.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if result["status"] == "error":
                raise RuntimeError(result["message"])
            assert result["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Subclass hook; the result is compared against ``self.local_address`` in
            ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Subclass hook called by ``receive`` with already-unpickled shards.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and… await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module logger; used below for shuffle restart and broadcast-retry warnings.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers on ``scheduler`` and install this plugin."""
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # heartbeats[shuffle_id][worker_address] -> merged heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Order preserved: the plugin is added before the remaining bookkeeping
            # attributes are created.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

            Uses ``self.scheduler`` (set in ``__init__``) rather than the ``scheduler``
            argument.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                pickled_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return {shuffle_id for shuffle_id in self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` RPC for run ``run_id`` of shuffle ``id``.

            If ``consistent`` is false, the shuffle is restarted instead. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to every participating worker;
            workers whose broadcast failed with ``OSError`` are retried after a short
            sleep.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a worker to retry is unknown to the scheduler while the shuffle is
                not archived.
            RuntimeError
                On a non-``OSError`` broadcast failure, if a worker to retry left the
                cluster, or after three retry rounds in which the number of failing
                workers did not shrink.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # A None result is treated as success. OSError (comm failure) is
                # retried; any other returned error aborts the barrier. This list is
                # the complete retry set (the previous duplicate recomputation from
                # ``res`` was dead code producing the same list).
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # no_progress is never reset: it counts every round (not
                    # necessarily consecutive) where no failing worker recovered.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run.

            A request whose ``run_id`` is older or newer than the active run is
            answered with an error message instead. A KeyError for an unknown shuffle
            or task propagates.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run = shuffle.run_id
                if active_run > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat ``data`` from worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle IDs to dicts of heartbeat values; entries for shuffles
            that are not currently active are ignored.
            """
            # Build the active-ID set once instead of once per entry in ``data``.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` for ``worker``.

            An unknown shuffle (KeyError) is reported as a ``P2PConsistencyError``
            error message, as is any consistency error raised by ``_get``.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
            return reply
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Record ``worker`` as a participant of shuffle ``id`` and return its run spec.

            Raises KeyError if ``id`` is not active.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33805', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:46259', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:43611', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.64399807, 0.31063259, 0.17543577, 0.82783572, 0.74638352,
        0.19865671, 0.16140169, 0.6856798 , 0.0280..., 0.60435229, 0.88620013, 0.03875494, 0.32640197,
        0.12674629, 0.42828326, 0.19082475, 0.71912404, 0.68470373]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Shuffle identifier: a plain ``str`` at runtime, a distinct type for type checkers.
    ShuffleId = NewType("ShuffleId", str)
    # Multi-dimensional partition index; _write_to_disk/_read_from_disk join its
    # elements with "_" to build the disk-buffer key.
    NDIndex: TypeAlias = tuple[int, ...]


    # Type parameters of ShuffleRun: the partition identifier and partition payload types.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic result type (used by ShuffleRun.offload).
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # OK reply that carries a shuffle run spec, either directly or wrapped in ToPickle.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Worker-side state of one run of a P2P shuffle.

        Outgoing shards are buffered in a ``CommShardsBuffer`` that sends them to peers
        via ``send``. Incoming shards arrive through ``receive`` and are handed to the
        subclass's ``_receive``. The receive-side buffer is a ``DiskShardsBuffer`` or a
        ``MemoryShardsBuffer``, chosen by the ``disk`` constructor argument.

        Subclasses provide ``_get_assigned_worker``, ``_receive``, ``_shard_partition``,
        ``_get_output_partition``, ``read`` and ``deserialize``.

        NOTE: those hooks are decorated with ``abc.abstractmethod``, but this class's
        only base is ``Generic`` (no ``ABCMeta``), so Python does not refuse to
        instantiate a subclass that leaves one unimplemented.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        # Set by inputs_done(); once True, no more partitions may be added.
        transferred: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Create the buffers and read retry settings for a shuffle run.

            Parameters
            ----------
            id, run_id
                Shuffle identifier and the identifier of this run of it.
            span_id
                Span under which p2p metrics are recorded (see ``_capture_metrics``).
            local_address
                Address of this worker; compared with assigned output workers.
            directory
                Directory for the ``DiskShardsBuffer`` (only used when ``disk``).
            executor
                Thread pool used by ``offload``.
            rpc
                Returns an RPC handle for a peer address.
            digest_metric
                Sink that receives captured metrics.
            scheduler
                RPC handle to the scheduler.
            memory_limiter_disk, memory_limiter_comms
                Resource limiters for the disk and comm buffers.
            disk
                True to buffer received shards on disk, False to keep them in memory.
            loop
                Event loop the synchronous entry points ``sync`` onto.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Guard against an empty tuple label, which previously raised
                # IndexError on ``label[0]``. A leading "p2p-" is stripped because
                # the metric is already namespaced under "p2p".
                if label and isinstance(label[0], str):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            # allow_offload is enabled only when ``where`` names a background context.
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report reaching the barrier; consistent only if all ``run_ids`` match ours.

            Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Issue one ``shuffle_receive`` RPC to ``address`` (no retries)."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` with retries (RETRY_* settings).

            Shards with a mean size below 65536 bytes are pickled into one blob.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return a snapshot with "disk", "comm" (plus "read" bytes) and "start"."""
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` into the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the receive buffer keyed by "_"-joined indices."""
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise the recorded failure, or ShuffleClosedError, if this run is closed."""
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark inputs as transferred, flush comms and surface any comm error."""
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close the comm and disk buffers.

            Subsequent calls wait for the first call to finish. The disk buffer is
            closed and the closed event is set even if closing the comm buffer raises,
            so concurrent ``close()`` callers never wait forever.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as this run's failure unless already closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards from a peer; consistency errors become error messages."""
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads on network input; relies on peers
                    # being trusted cluster workers.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Raise Reschedule (after restricting ``key``) if another worker owns ``i``."""
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition ``data`` and queue it for sending; return run_id.

            Raises RuntimeError once ``inputs_done`` has marked the run transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id``.

            Reschedules via ``_ensure_output_worker`` if this worker is not assigned
            the partition; raises RuntimeError if called before the barrier.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P e07f19a3ad28b5edecde97f5c6e27492 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_configuration[p2p-None] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P fe1d1aff4e980bf6369e4d4adf98f81b failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'fe1d1aff4e980bf6369e4d4adf98f81b'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…= await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:44351', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:42471', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41635', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.05373005, 0.04687012, 0.18475967, 0.63853183, 0.37113315,
        0.34780472, 0.59389736, 0.66525719, 0.1175..., 0.43862109, 0.26487191, 0.08598918, 0.65790677,
        0.24254732, 0.67961834, 0.27260373, 0.6418115 , 0.75985999]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P fe1d1aff4e980bf6369e4d4adf98f81b failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_configuration[None-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P aad191a3b35f209c6fd930cfec08e9c4 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'aad191a3b35f209c6fd930cfec08e9c4'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…= await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:32987', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = None
ws = (<Worker 'tcp://127.0.0.1:43423', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40799', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.6420339 , 0.83288408, 0.4242612 , 0.40240756, 0.42555848,
        0.10922281, 0.28362424, 0.53933638, 0.4652..., 0.49628131, 0.43796726, 0.73645182, 0.75935205,
        0.0463592 , 0.93290705, 0.51500255, 0.631823  , 0.02839664]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is produced on its assigned worker.

            If that worker is this one, return without doing anything. Otherwise
            ask the scheduler to restrict task ``key`` to the assigned worker and
            raise ``Reschedule``. If the scheduler answers with an error, raise
            ``RuntimeError`` carrying its message.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if response["status"] == "error":
                raise RuntimeError(response["message"])
            assert response["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker assigned to output partition ``i``."""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Store shards received for output partitions of this shuffle run.

            Called by ``receive`` with a list of ``(partition_id, payload)`` pairs.
            Any pickled payload has already been unpickled at that point.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard one input partition and hand the shards to the comm buffer.

            Returns this run's ``run_id``. Raises ``RuntimeError`` once the run
            has been marked as ``transferred``.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-shard-partition-noncpu"):
                    with context_meter.meter(
                        "p2p-shard-partition-cpu", func=thread_time
                    ):
                        shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Split an input partition into shards grouped by destination.

            The returned dict is keyed by ``str``. The keys are presumably the
            addresses of the receiving output workers; TODO confirm in subclasses.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` for task ``key``.

            The task is first rescheduled if it runs on the wrong worker. Raises
            ``RuntimeError`` if the barrier has not marked the run as
            ``transferred`` yet. Otherwise the receive buffer is flushed before
            the partition is assembled.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Assemble and return output partition ``partition_id`` of this run.

            ``get_output_partition`` calls this only after the receive buffer
            has been flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from the file at ``path``.

            Returns the shards together with an ``int``, presumably the number
            of bytes read; TODO confirm in subclasses.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize the shards contained in ``buffer``."""
    
        def validate_data(self, data: Any) -> None:
            """Hook to validate an input partition before it is sharded.

            ``add_partition`` calls it for every partition. The default
            implementation accepts everything; subclasses may raise to reject
            data.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the P2P shuffle plugin of the worker executing the current task.

        Raises
        ------
        RuntimeError
            If no distributed Worker is running the current task, or if that
            worker has no plugin registered under the name ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as exc:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from exc

        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from exc
    
    
    # Every barrier task key is this prefix followed by the shuffle id.
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task for shuffle ``shuffle_id``.

        >>> barrier_key("abc")
        'shuffle-barrier-abc'
        """
        prefix = _BARRIER_PREFIX
        return prefix + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier task key.

        Returns ``None`` if ``key`` is not a string beginning with the barrier
        prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle.

        ``run_id`` is not an ``__init__`` argument. Each new instance takes the
        next value from one counter starting at 1, shared by every instance in
        this process.
        """

        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Output partition id -> address of the worker assigned to it
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle this run belongs to (taken from ``spec``)."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a P2P shuffle.

        Concrete subclasses define:

        - the output partitions;
        - which worker owns each partition;
        - how the worker-side ``ShuffleRun`` is built.
        """

        id: ShuffleId
        # NOTE(review): presumably selects on-disk vs. in-memory buffering of
        # received shards (cf. ``ShuffleRun(disk=...)``); confirm in subclasses.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Ids of all output partitions, as a generator."""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose, from ``workers``, the worker that should own ``partition``."""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create scheduler-side state for a fresh run of this shuffle.

            Every worker that appears as a value in ``worker_for`` starts out
            as a participating worker.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            initial_workers = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=run_spec,
                participating_workers=initial_workers,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Build the worker-side ``ShuffleRun`` for run ``run_id``.

            ``plugin`` is the worker's ``ShuffleWorkerPlugin``.
            """
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one run of a shuffle.

        ``eq=False`` keeps identity-based equality. Hashing by ``run_id`` is
        consistent with that, because identical objects share a ``run_id``.
        """

        run_spec: ShuffleRunSpec
        # Addresses of the workers taking part in this run
        participating_workers: set[str]
        # Not an __init__ argument; the run counts as archived once this is not None
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Id of this run."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return "{}<{}[{}]>".format(self.__class__.__name__, self.id, self.run_id)

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised in the body during the transfer phase of ``id``.

        - ``ShuffleClosedError`` is replaced by ``Reschedule``.
        - ``P2PConsistencyError`` and ``P2POutOfDiskError`` propagate unchanged.
        - Any other ``Exception`` is wrapped in a ``RuntimeError`` that names
          the shuffle, with the original chained as ``__cause__``.
        """
        try:
            yield
        except ShuffleClosedError:
            # NOTE(review): presumably so the task is retried elsewhere; confirm.
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P aad191a3b35f209c6fd930cfec08e9c4 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_configuration[None-None] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P ed041f685e395fded8871501929b9a5e failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)


    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler.

        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.

        See Also
        --------
        ShuffleWorkerPlugin
        """

        scheduler: Scheduler
        # Current run state per shuffle id
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        # shuffle id -> worker address -> merged heartbeat payload (see heartbeat())
        heartbeats: defaultdict[ShuffleId, dict]
        # Run states created per shuffle id (added to in _create)
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        # presumably keyed by stimulus id; usage is outside this view
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle RPC handlers on ``scheduler`` and register as its "shuffle" plugin."""
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # The plugin is registered before the remaining bookkeeping below is
            # initialised; the original statement order is preserved.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new, pickled ``ShuffleWorkerPlugin`` under the name "shuffle"."""
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a fresh set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` RPC for run ``run_id`` of shuffle ``id``.

            If ``consistent`` is false, the shuffle is restarted instead.
            Otherwise ``shuffle_inputs_done`` is broadcast to every participating
            worker. Workers whose broadcast failed with ``OSError`` are retried.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` is not the current run of the shuffle.
            P2PIllegalStateError
                If a participating worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker returned a non-``OSError`` error, a participating
                worker left, or three retry rounds made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Build the retry list. A ``None`` result is treated as success,
                # an OSError as transient, and anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Pin task ``key`` to ``worker`` for run ``run_id`` of shuffle ``id``.

            Returns an OK message on success. If ``run_id`` is older ("stale") or
            newer ("invalid") than the scheduler's current run, an error message
            is returned instead. The ``id`` and ``key`` lookups are not guarded,
            so a missing entry propagates as an exception.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data from worker ``ws`` into ``self.heartbeats``.

            Entries for shuffle ids that are not currently active are ignored.
            """
            # Build the set of active ids once, instead of once per entry
            # (the loop body never changes ``active_shuffles``).
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Handle the ``shuffle_get`` RPC for shuffle ``id`` on behalf of ``worker``.

            On success, returns the pickled run spec. An unknown shuffle, or any
            other ``P2PConsistencyError``, is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of active shuffle ``id`` and record ``worker`` as a participant.

            Raises ``P2PConsistencyError`` if the scheduler does not know
            ``worker``, and ``KeyError`` if ``id`` is not an active shuffle.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` held by the barrier task of ``shuffle_id``."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_ts.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run of ``shuffle_id``, triggered by task ``key`` on ``worker``.

            The new state becomes the active run, is added to ``_shuffles``, and
            has ``worker`` recorded as a participant. Returns its run spec.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = spec.create_new_run(worker_for=worker_for, span_id=span_id)

            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)


    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler.

        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.

        See Also
        --------
        ShuffleWorkerPlugin
        """

        scheduler: Scheduler
        # Current run state per shuffle id
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        # shuffle id -> worker address -> merged heartbeat payload (see heartbeat())
        heartbeats: defaultdict[ShuffleId, dict]
        # Run states created per shuffle id
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        # presumably keyed by stimulus id; usage is outside this view
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle RPC handlers on ``scheduler`` and register as its "shuffle" plugin."""
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # The plugin is registered before the remaining bookkeeping below is
            # initialised; the original statement order is preserved.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new, pickled ``ShuffleWorkerPlugin`` under the name "shuffle"."""
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a fresh set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` RPC for run ``run_id`` of shuffle ``id``.

            If ``consistent`` is false, the shuffle is restarted instead.
            Otherwise ``shuffle_inputs_done`` is broadcast to every participating
            worker. Workers whose broadcast failed with ``OSError`` are retried.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` is not the current run of the shuffle.
            P2PIllegalStateError
                If a participating worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker returned a non-``OSError`` error, a participating
                worker left, or three retry rounds made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Build the retry list. A ``None`` result is treated as success,
                # an OSError as transient, and anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Pin task ``key`` to ``worker`` for run ``run_id`` of shuffle ``id``.

            Returns an OK message on success. If ``run_id`` is older ("stale") or
            newer ("invalid") than the scheduler's current run, an error message
            is returned instead. The ``id`` and ``key`` lookups are not guarded,
            so a missing entry propagates as an exception.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data from worker ``ws`` into ``self.heartbeats``.

            Entries for shuffle ids that are not currently active are ignored.
            """
            # Build the set of active ids once, instead of once per entry
            # (the loop body never changes ``active_shuffles``).
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Handle the ``shuffle_get`` RPC for shuffle ``id`` on behalf of ``worker``.

            On success, returns the pickled run spec. An unknown shuffle, or any
            other ``P2PConsistencyError``, is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'ed041f685e395fded8871501929b9a5e'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Identifier of a shuffle (a plain ``str`` at runtime).
    ShuffleId = NewType("ShuffleId", str)
    # Index of a partition in an n-dimensional grid, e.g. ``(0, 1)``.
    NDIndex: TypeAlias = tuple[int, ...]


    # Type of the partition identifiers handled by a shuffle
    _T_partition_id = TypeVar("_T_partition_id")
    # Type of the partition data handled by a shuffle
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """OK message that also carries a run spec, either as-is or wrapped in ``ToPickle``."""

        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Worker-side state of one run of a P2P shuffle.

        Outgoing shards go through a ``CommShardsBuffer``, which sends via
        ``self.send``. Received shards are held in a ``DiskShardsBuffer`` or a
        ``MemoryShardsBuffer``, chosen by the ``disk`` constructor argument.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        # False in __init__; checked by add_partition / get_output_partition to
        # tell whether the transfer phase has been marked done
        transferred: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        # Wall-clock time (time.time()) at construction
        start_time: float
        _exception: Exception | None
        # Set once close() has finished tearing down the buffers
        _closed_event: asyncio.Event
        _loop: IOLoop

        # Read from the distributed.p2p.comm.retry.* config in __init__
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Create the buffers and bookkeeping for one shuffle run on this worker.

            Parameters
            ----------
            directory
                Location of the ``DiskShardsBuffer``; used only when ``disk`` is
                true.
            disk
                If true, received shards go to a ``DiskShardsBuffer``; otherwise
                to a ``MemoryShardsBuffer``.
            memory_limiter_disk, memory_limiter_comms
                Limiters passed to the disk buffer and the comm buffer.

            The comm buffer limits come from the ``distributed.p2p.comm.message-bytes-limit``
            and ``distributed.p2p.comm.concurrency`` config values. The retry policy
            comes from ``distributed.p2p.comm.retry.{count,delay.min,delay.max}``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Delay bounds accept timedelta strings; bare numbers are seconds
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Debug representation listing id, run id, address and state flags."""
            cls_name = self.__class__.__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Short form: ``ClassName<shuffle id[run id]> on address``."""
            return "{}<{}[{}]> on {}".format(
                self.__class__.__name__, self.id, self.run_id, self.local_address
            )
    
        def __hash__(self) -> int:
            """Hash a run by its integer ``run_id``."""
            run_id: int = self.run_id
            return run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Route ``context_meter`` metrics recorded in this context to ``self.digest_metric``.

            Each metric is reported under the key

                ('p2p', <span id>, <where>, *label, unit)

            after stripping a leading ``"p2p-"`` from the first label element.
            Offloading is only allowed when *where* contains ``"background"``.

            When used from a foreground task, the same metric is typically also
            captured under ``('execute', <span id>, <task prefix>, label, unit)``;
            that duplication is intentional and gives a holistic view of the whole
            shuffle.

            Metrics are written straight to the digest instead of being buffered on
            the run: anything buffered between the last heartbeat and the deletion
            of the run object at the end of the shuffle would otherwise be lost.
            """
            prefix = "p2p-"

            def _record(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                first = parts[0]
                if isinstance(first, str) and first.startswith(prefix):
                    parts = (first[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offload_ok = "background" in where
            with context_meter.add_callback(_record, allow_offload=offload_ok):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Tell the scheduler this worker reached the barrier and return ``self.run_id``.

            *run_ids* are the run ids reported by the input tasks; any id that
            differs from this run's id is forwarded as ``consistent=False``.
            """
            self.raise_if_closed()
            mismatched = any(other != self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatched
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Issue a single ``shuffle_receive`` RPC carrying *shards* to *address*."""
            self.raise_if_closed()
            remote = self.rpc(address)
            payload = to_serialize(shards)
            return await remote.shuffle_receive(
                data=payload, shuffle_id=self.id, run_id=self.run_id
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to *address*, retrying according to the run's retry policy.

            Batches whose mean shard size is below 64 KiB are pickled into one
            opaque bytes blob (unpacked again in ``receive``) instead of sending
            each buffer separately over the comm; see
            https://github.com/dask/distributed/pull/8318 for the benchmarks behind
            the threshold.
            """
            small_shards = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small_shards else shards

            def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                _attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` under the "offload" meter."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect disk/comm buffer heartbeats plus the bytes received so far and start time."""
            comm = self._comm_buffer.heartbeat()
            comm["read"] = self.total_recvd
            disk = self._disk_buffer.heartbeat()
            return {"disk": disk, "comm": comm, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand outgoing shards, keyed by destination, to the comm buffer."""
            self.raise_if_closed()
            buffer = self._comm_buffer
            await buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Store shards in the disk buffer under ``"_"``-joined index keys (``(0, 1)`` -> ``"0_1"``)."""
            self.raise_if_closed()
            keyed = {}
            for index, shard in data.items():
                keyed["_".join(map(str, index))] = shard
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if the run is closed: the recorded failure if any, else ShuffleClosedError."""
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark all inputs as transferred, flush the comm buffer and surface send errors.

            A failure reported by the comm buffer is remembered in ``_exception``
            before being re-raised.
            """
            self.raise_if_closed()
            # Set before flushing so add_partition rejects any late input.
            self.transferred = True
            await self._flush_comm()
            comm = self._comm_buffer
            try:
                comm.raise_on_exception()
            except Exception as failure:
                self._exception = failure
                raise
    
        async def _flush_comm(self) -> None:
            """Flush every pending outgoing shard in the comm buffer."""
            self.raise_if_closed()
            buffer = self._comm_buffer
            await buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush shards received so far out of the disk (or memory) buffer."""
            self.raise_if_closed()
            buffer = self._disk_buffer
            await buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers; later or concurrent calls wait for the first.

            ``_closed_event`` is set in a ``finally`` block. A failure while closing
            either buffer therefore cannot leave other ``close()`` callers waiting
            forever. The disk buffer is still closed if closing the comm buffer
            raises. Such an exception propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record *exception* as the run's failure unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read shards stored under the ``"_"``-joined form of index *id*."""
            self.raise_if_closed()
            key = "_".join(map(str, id))
            return self._disk_buffer.read(key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards from a peer; *data* may be the pickled blob produced by ``send``.

            Returns an OK message, or an error message if ``_receive`` raises
            P2PConsistencyError.
            """
            try:
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition *i* (task *key*) runs on its assigned worker.

            Returns silently when this worker is the assigned one. Otherwise asks
            the scheduler to restrict *key* to the assigned worker and raises
            ``Reschedule``. A scheduler error reply is raised as RuntimeError.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            reply = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if reply["status"] == "error":
                raise RuntimeError(reply["message"])
            assert reply["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker that must produce output partition *i*."""
            ...
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Store incoming ``(partition id, shard)`` pairs that belong to this run."""
            ...
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and… = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module logger; ShuffleSchedulerPlugin.barrier warns through it about
    # inconsistent runs and broadcast retries.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the shuffle extension to *scheduler*.

            Registers the RPC handlers used by workers (``shuffle_barrier``,
            ``shuffle_get``, ``shuffle_get_or_create``, ``shuffle_restrict_task``),
            initialises all bookkeeping containers and only then registers itself
            as the scheduler plugin named ``"shuffle"``.
            """
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # heartbeats[shuffle_id][worker_address] -> merged heartbeat payload
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
            # Register last: add_plugin hands ``self`` to the scheduler, so every
            # attribute must already exist in case a plugin hook runs right away.
            self.scheduler.add_plugin(self, name="shuffle")
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ShuffleWorkerPlugin named "shuffle" with ``self.scheduler``."""
            plugin_bytes = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, plugin_bytes, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a fresh set with the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` RPC for shuffle *id*.

            Raises ValueError if *run_id* is not the active run. If the caller
            reported inconsistent input run ids, the shuffle is restarted instead.
            Otherwise ``shuffle_inputs_done`` is broadcast to every participating
            worker. Only workers whose broadcast failed with an ``OSError`` are
            retried. The barrier aborts with an exception on any other error, when
            a failing worker is unknown to the scheduler, or after three rounds in
            which the number of failing workers did not shrink.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Counts rounds without progress; intentionally not reset when a later
            # round does make progress.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # A ``None`` result means success; OSErrors are retried; anything
                # else aborts. The list built here is the final retry set.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task *key* to *worker* on behalf of run *run_id* of shuffle *id*.

            Returns an OK message, or an error message when *run_id* is older
            ("stale") or newer ("invalid") than the active run. A ``KeyError``
            for an unknown shuffle or task is not caught here.
            """
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat payloads from worker *ws* into ``self.heartbeats``.

            *data* maps shuffle ids to payload dicts. Entries for shuffles that are
            not currently active are ignored.
            """
            # Snapshot the active ids once; calling shuffle_ids() per entry
            # rebuilt the whole set on every iteration.
            active = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Handle ``shuffle_get``: return the run spec of shuffle *id* for *worker*.

            An unknown shuffle id (``KeyError`` from ``_get``) or any other
            P2PConsistencyError is turned into an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply: RunSpecMessage = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return reply
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add *worker* to the participants of shuffle *id* and return its run spec.

            Raises ``KeyError`` for an unknown shuffle id.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:42283', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = None
ws = (<Worker 'tcp://127.0.0.1:38627', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40285', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.83001867, 0.80178537, 0.23115057, 0.22929814, 0.05555555,
        0.53942166, 0.84898209, 0.27603218, 0.0611..., 0.75310379, 0.19612405, 0.08756546, 0.0405993 ,
        0.66275933, 0.59495111, 0.77384036, 0.72799921, 0.69120823]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Check how ``rechunk`` picks its algorithm and that the result is correct.

        A random 10x10 array chunked ``(10, 1)`` is rechunked to
        ``((1,) * 10, (10,))``. The ``method`` keyword takes precedence over the
        ``array.rechunk.method`` config; when neither is set, the P2P algorithm
        is expected to be chosen for this input.

        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        # The keyword wins over the config value.
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())

        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Identifier of a shuffle; each (re)run of it additionally has a run_id.
    ShuffleId = NewType("ShuffleId", str)
    # N-dimensional partition index; the disk buffer stores entries under the
    # "_"-joined string form of this tuple.
    NDIndex: TypeAlias = tuple[int, ...]


    # Type parameters of ShuffleRun: partition identifiers and partition payloads.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic result type, e.g. of ShuffleRun.offload.
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # OK reply carrying a shuffle run spec, e.g. from ShuffleSchedulerPlugin.get,
        # which wraps the spec in ToPickle for transport.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up one run of a shuffle on this worker.

            Parameters
            ----------
            id, run_id
                Identify the shuffle and this particular run of it.
            span_id
                Span used as the second element of every metric key emitted by
                ``_capture_metrics``.
            local_address
                Address of this worker; compared with the assigned worker in
                ``_ensure_output_worker``.
            directory
                Spill directory passed to ``DiskShardsBuffer`` when *disk* is true.
            executor
                Thread pool used by ``offload``.
            rpc
                Returns an RPC handle for a peer address; used by ``_send``.
            digest_metric
                Sink for metrics captured by ``_capture_metrics``.
            scheduler
                RPC handle to the scheduler (barrier, task restriction).
            memory_limiter_disk, memory_limiter_comms
                Limiters for the disk buffer (only used when *disk* is true) and
                the comm buffer respectively.
            disk
                If true, received shards go to a ``DiskShardsBuffer``; otherwise to
                a ``MemoryShardsBuffer``.
            loop
                Event loop that ``add_partition`` uses, via ``sync``, to hand shards
                to the comm buffer.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Lifecycle / progress state; transferred flips in inputs_done().
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy used by send(); delays without a unit are parsed as seconds.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Debug representation showing identity, location and lifecycle flags."""
            cls_name = self.__class__.__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Short label of the form ``<Class><id[run_id]> on <address>``."""
            label = f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
            return f"{label} on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its integer ``run_id``."""
            run_id: int = self.run_id
            return run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Route ``context_meter`` metrics recorded in this context to ``self.digest_metric``.

            Each metric is reported under the key

                ('p2p', <span id>, <where>, *label, unit)

            after stripping a leading ``"p2p-"`` from the first label element.
            Offloading is only allowed when *where* contains ``"background"``.

            When used from a foreground task, the same metric is typically also
            captured under ``('execute', <span id>, <task prefix>, label, unit)``;
            that duplication is intentional and gives a holistic view of the whole
            shuffle.

            Metrics are written straight to the digest instead of being buffered on
            the run: anything buffered between the last heartbeat and the deletion
            of the run object at the end of the shuffle would otherwise be lost.
            """
            prefix = "p2p-"

            def _record(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                first = parts[0]
                if isinstance(first, str) and first.startswith(prefix):
                    parts = (first[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offload_ok = "background" in where
            with context_meter.add_callback(_record, allow_offload=offload_ok):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Tell the scheduler this worker reached the barrier and return ``self.run_id``.

            *run_ids* are the run ids reported by the input tasks; any id that
            differs from this run's id is forwarded as ``consistent=False``.
            """
            self.raise_if_closed()
            mismatched = any(other != self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatched
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Issue a single ``shuffle_receive`` RPC carrying *shards* to *address*."""
            self.raise_if_closed()
            remote = self.rpc(address)
            payload = to_serialize(shards)
            return await remote.shuffle_receive(
                data=payload, shuffle_id=self.id, run_id=self.run_id
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to *address*, retrying according to the run's retry policy.

            Batches whose mean shard size is below 64 KiB are pickled into one
            opaque bytes blob (unpacked again in ``receive``) instead of sending
            each buffer separately over the comm; see
            https://github.com/dask/distributed/pull/8318 for the benchmarks behind
            the threshold.
            """
            small_shards = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small_shards else shards

            def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                _attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` under the "offload" meter."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect disk/comm buffer heartbeats plus the bytes received so far and start time."""
            comm = self._comm_buffer.heartbeat()
            comm["read"] = self.total_recvd
            disk = self._disk_buffer.heartbeat()
            return {"disk": disk, "comm": comm, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand outgoing shards, keyed by destination, to the comm buffer."""
            self.raise_if_closed()
            buffer = self._comm_buffer
            await buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Store shards in the disk buffer under ``"_"``-joined index keys (``(0, 1)`` -> ``"0_1"``)."""
            self.raise_if_closed()
            keyed = {}
            for index, shard in data.items():
                keyed["_".join(map(str, index))] = shard
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if the run is closed: the recorded failure if any, else ShuffleClosedError."""
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark all inputs as transferred, flush the comm buffer and surface send errors.

            A failure reported by the comm buffer is remembered in ``_exception``
            before being re-raised.
            """
            self.raise_if_closed()
            # Set before flushing so add_partition rejects any late input.
            self.transferred = True
            await self._flush_comm()
            comm = self._comm_buffer
            try:
                comm.raise_on_exception()
            except Exception as failure:
                self._exception = failure
                raise
    
        async def _flush_comm(self) -> None:
            """Flush every pending outgoing shard in the comm buffer."""
            self.raise_if_closed()
            buffer = self._comm_buffer
            await buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush shards received so far out of the disk (or memory) buffer."""
            self.raise_if_closed()
            buffer = self._disk_buffer
            await buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers; later or concurrent calls wait for the first.

            ``_closed_event`` is set in a ``finally`` block. A failure while closing
            either buffer therefore cannot leave other ``close()`` callers waiting
            forever. The disk buffer is still closed if closing the comm buffer
            raises. Such an exception propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record *exception* as the run's failure unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read shards stored under the ``"_"``-joined form of index *id*."""
            self.raise_if_closed()
            key = "_".join(map(str, id))
            return self._disk_buffer.read(key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards from a peer; *data* may be the pickled blob produced by ``send``.

            Returns an OK message, or an error message if ``_receive`` raises
            P2PConsistencyError.
            """
            try:
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P ed041f685e395fded8871501929b9a5e failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_heuristic[new0-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 4eafd5a6ec56ad19493d1adddac5f652 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '4eafd5a6ec56ad19493d1adddac5f652'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` prefix on the first label component is stripped.
            ``allow_offload`` is enabled only when *where* contains ``"background"``.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Guard the empty-tuple label: indexing label[0] unconditionally
                # would raise IndexError inside the metering callback.
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` RPC to *address* carrying *shards*.

            *shards* is either a list of ``(partition id, shard)`` pairs or the
            pickled blob built by ``send``. Raises if this run is closed.
            """
            self.raise_if_closed()
            remote = self.rpc(address)
            payload = to_serialize(shards)
            return await remote.shuffle_receive(
                data=payload, shuffle_id=self.id, run_id=self.run_id
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to the worker at *address*, retrying on failure.

            When the mean shard size is below 64 KiB, the shards are pickled into
            one opaque bytes blob instead of being sent buffer by buffer over the
            comms; ``receive`` unpickles such blobs on the other side. Retries use
            ``RETRY_COUNT`` / ``RETRY_DELAY_MIN`` / ``RETRY_DELAY_MAX``.
            """
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            small_shards = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small_shards else shards
            # Zero-argument factory so each retry attempt builds a new coroutine.
            attempt = partial(self._send, address, payload)
            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor and await it.

            The call is metered under the ``"offload"`` label. Raises if the run
            is closed.
            """
            self.raise_if_closed()
            pool = self.executor
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(pool, func, *args, **kwargs)
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler) -> None:
            """Attach the shuffle plugin to *scheduler*.

            Registers the ``shuffle_*`` RPC handlers on the scheduler, adds this
            object as the scheduler plugin named ``"shuffle"`` and initialises
            the bookkeeping containers.
            """
            self.scheduler = scheduler
            # RPC endpoints called by workers (e.g. ``ShuffleRun.barrier`` sends
            # ``shuffle_barrier``).
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # heartbeats[shuffle_id][worker_address] -> merged heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE(review): the plugin is registered before the remaining
            # attributes below are set; keep this order unless add_plugin is
            # known not to use them.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named ``"shuffle"``.

            Registration goes through ``self.scheduler``; the *scheduler*
            argument itself is not used.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Handle ``shuffle_get``: return the run spec of active shuffle *id*.

            *worker* is recorded as a participant. An unknown shuffle or worker
            is reported as an error message built from ``P2PConsistencyError``
            instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:41493', workers: 0, cores: 0, tasks: 0>
a = array([[0.37151424, 0.96300086, 0.03572343, ..., 0.13664586, 0.79334727,
        0.45763255],
       [0.6307737 , 0.37...21,
        0.63065886],
       [0.67865086, 0.35168033, 0.19536338, ..., 0.7557135 , 0.97195412,
        0.57730087]])
b = <Worker 'tcp://127.0.0.1:39939', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((1, 1, 1, 1, 1, 1, ...), (100,)), expected_algorithm = 'p2p'

    @pytest.mark.parametrize(
        ["new", "expected_algorithm"],
        [
            # All-to-all rechunking defaults to P2P
            (((1,) * 100, (100,)), "p2p"),
            # Localized rechunking defaults to tasks
            (((50, 50), (2,) * 50), "tasks"),
            # Less local rechunking first defaults to tasks,
            (((25, 25, 25, 25), (4,) * 25), "tasks"),
            # then switches to p2p
            (((10,) * 10, (10,) * 10), "p2p"),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
        """Rechunking to *new* picks *expected_algorithm* and preserves the data."""
        # Named ``arr`` (not ``a``) so the worker fixture ``a`` is not shadowed;
        # the shadowing made failure reports print the array as worker ``a``.
        arr = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
        x = da.from_array(arr, chunks=(100, 1))
        x2 = rechunk(x, chunks=new)
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        else:
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )

        assert x2.chunks == new
        assert np.all(await c.compute(x2) == arr)

distributed/shuffle/tests/test_rechunk.py:239: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Identifier of a shuffle; a distinct ``str`` type for static checking.
    ShuffleId = NewType("ShuffleId", str)
    # N-dimensional integer index, e.g. ``(0, 3)``; flattened to ``"0_3"`` when
    # used as a disk-buffer key (see ``ShuffleRun._write_to_disk``).
    NDIndex: TypeAlias = tuple[int, ...]


    # Type parameters of ``ShuffleRun``: the partition id type and the
    # partition payload type passed to ``add_partition``.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic return type (e.g. of ``ShuffleRun.offload``).
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # OK reply carrying a shuffle run spec, possibly wrapped in ``ToPickle``
        # (e.g. as returned by ``ShuffleSchedulerPlugin.get``).
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ) -> None:
            """Set up buffers, bookkeeping and retry settings for one shuffle run.

            Parameters
            ----------
            id, run_id:
                Identify the shuffle and this particular run of it.
            span_id:
                Included in the names of metrics captured by ``_capture_metrics``.
            local_address:
                Address of the worker hosting this run.
            directory:
                Directory for the ``DiskShardsBuffer`` (used only if ``disk``).
            executor:
                Executor used by ``offload``.
            rpc:
                Returns an RPC handle for a peer address; ``_send`` uses it to
                call ``shuffle_receive``.
            digest_metric:
                Receives ``(metric name, value)`` for every captured metric.
            scheduler:
                RPC handle to the scheduler (``shuffle_barrier``,
                ``shuffle_restrict_task``).
            memory_limiter_disk, memory_limiter_comms:
                Resource limiters handed to the disk and comm buffers.
            disk:
                Use a ``DiskShardsBuffer`` if true, otherwise a
                ``MemoryShardsBuffer``.
            loop:
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Message size cap and parallelism come from the
                    # ``distributed.p2p.comm`` config section.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Set to True by ``inputs_done``; ``add_partition`` then refuses input.
            self.transferred = False
            self.received = set()
            # ``total_recvd`` and ``start_time`` are reported by ``heartbeat``.
            self.total_recvd = 0
            self.start_time = time.time()
            # Failure recorded by ``fail``/``inputs_done``; re-raised by
            # ``raise_if_closed`` once the run is closed.
            self._exception = None
            # Set when ``close`` finishes; repeated ``close`` calls wait on it.
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy applied by ``send`` to each outgoing RPC.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` prefix on the first label component is stripped.
            ``allow_offload`` is enabled only when *where* contains ``"background"``.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Guard the empty-tuple label: indexing label[0] unconditionally
                # would raise IndexError inside the metering callback.
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` RPC to *address* carrying *shards*.

            *shards* is either a list of ``(partition id, shard)`` pairs or the
            pickled blob built by ``send``. Raises if this run is closed.
            """
            self.raise_if_closed()
            remote = self.rpc(address)
            payload = to_serialize(shards)
            return await remote.shuffle_receive(
                data=payload, shuffle_id=self.id, run_id=self.run_id
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to the worker at *address*, retrying on failure.

            When the mean shard size is below 64 KiB, the shards are pickled into
            one opaque bytes blob instead of being sent buffer by buffer over the
            comms; ``receive`` unpickles such blobs on the other side. Retries use
            ``RETRY_COUNT`` / ``RETRY_DELAY_MIN`` / ``RETRY_DELAY_MAX``.
            """
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            small_shards = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small_shards else shards
            # Zero-argument factory so each retry attempt builds a new coroutine.
            attempt = partial(self._send, address, payload)
            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor and await it.

            The call is metered under the ``"offload"`` label. Raises if the run
            is closed.
            """
            self.raise_if_closed()
            pool = self.executor
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(pool, func, *args, **kwargs)
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 4eafd5a6ec56ad19493d1adddac5f652 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_heuristic[new3-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P c806e7671339aecf80206effa5e88ca8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'c806e7671339aecf80206effa5e88ca8'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Debug representation with the run's identity and lifecycle flags."""
            cls_name = self.__class__.__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Short label of the form ``<Class><id[run_id]> on <address>``."""
            label = f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
            return f"{label} on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its integer ``run_id``."""
            run_id: int = self.run_id
            return run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.

            A leading ``"p2p-"`` on the first label element is stripped, since the
            metric is already namespaced under ``"p2p"``. An empty tuple label is
            passed through unchanged instead of raising ``IndexError``.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Guard against an empty tuple: indexing label[0] would raise
                # IndexError from inside the metrics callback.
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            # Offloading is only allowed when capturing for a background context
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Notify the scheduler barrier whether every ``run_ids`` entry matches ours.

            Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            mismatched = any(other != self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatched
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` (a list or a pickled blob) to the peer at ``address``."""
            self.raise_if_closed()
            peer = self.rpc(address)
            payload = to_serialize(shards)
            return await peer.shuffle_receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` through ``retry`` with the configured policy.

            When the mean shard size is below 64 KiB, the shards are pickled into one
            opaque ``bytes`` blob. This avoids shipping many small buffers individually
            over the tcp comms; ``receive`` on the peer unpickles it.
            Performance tests informing the size threshold:
            https://github.com/dask/distributed/pull/8318
            """
            small = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small else shards

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect buffer statistics for this run.

            Combines the disk buffer's and comm buffer's heartbeat dicts; the comm
            part also gets ``"read"`` set to ``total_recvd``. ``"start"`` is the run's
            start time.
            """
            comm = self._comm_buffer.heartbeat()
            comm["read"] = self.total_recvd
            disk = self._disk_buffer.heartbeat()
            return {"disk": disk, "comm": comm, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer (raises if the run is closed)."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keyed by underscore-joined indices.

            For example, index ``(1, 2)`` becomes key ``"1_2"``.
            """
            self.raise_if_closed()
            flattened: dict[str, Any] = {}
            for index, value in data.items():
                flattened["_".join(map(str, index))] = value
            await self._disk_buffer.write(flattened)
    
        def raise_if_closed(self) -> None:
            """Raise when closed: the recorded failure if any, else ShuffleClosedError."""
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark inputs as transferred, flush comms, and surface any comm error.

            A comm-buffer error is stored in ``_exception`` before being re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as comm_error:
                self._exception = comm_error
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer (raises via ``raise_if_closed`` when closed)."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the receive-side (disk or memory) buffer."""
            self.raise_if_closed()
            receive_buffer = self._disk_buffer
            await receive_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers once.

            Repeat callers wait on ``_closed_event`` until the first close finishes.
            The disk buffer is closed and the event is set even if closing the comm
            buffer raises. Otherwise the disk buffer would leak and any later
            ``close()`` call would wait forever.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                await self._comm_buffer.close()
            finally:
                try:
                    await self._disk_buffer.close()
                finally:
                    # Always release waiters, even on partial failure
                    self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as this run's failure unless already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the buffered entry for index ``id`` (key ``"i_j_..."``)."""
            self.raise_if_closed()
            key = "_".join(map(str, id))
            return self._disk_buffer.read(key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards from a peer's ``send``.

            ``bytes`` payloads are the opaque pickled blobs produced by ``send`` and are
            unpickled first. A ``P2PConsistencyError`` is returned as an error message.

            NOTE(review): unpickling assumes the sender is a trusted cluster worker.
            """
            try:
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Ensure output partition ``i`` is computed on its assigned worker.

            If that worker is not this one, ask the scheduler to restrict task
            ``key`` to it, then raise ``Reschedule``.

            Raises
            ------
            RuntimeError
                If the scheduler answers the restriction request with an error.
            Reschedule
                After a successful restriction to another worker.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if response["status"] == "error":
                raise RuntimeError(response["message"])
            assert response["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker that owns output partition ``i``.

            Abstract: concrete shuffle runs must implement this.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Store incoming shards that belong to this run's output partitions.

            Abstract: concrete shuffle runs must implement this.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module logger; the scheduler plugin uses it for barrier/broadcast warnings
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install shuffle RPC handlers and register this plugin on ``scheduler``.

            All internal state is initialized before the plugin is passed to
            ``scheduler.add_plugin``. This way the scheduler never holds a
            partially-constructed plugin.
            """
            self.scheduler = scheduler
            # Expose the P2P coordination endpoints to workers via scheduler RPC
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
            # Register last, once every attribute exists
            self.scheduler.add_plugin(self, name="shuffle")
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

            NOTE: the ``scheduler`` argument is unused; ``self.scheduler`` is used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a fresh set of the IDs of all active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the P2P barrier for shuffle ``id``.

            If the caller reported inconsistent run IDs, the shuffle is restarted.
            Otherwise ``shuffle_inputs_done`` is broadcast to every participating
            worker, and workers that fail with ``OSError`` are retried.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active shuffle's run.
            RuntimeError
                If any of these happen:

                * a worker returns a non-``OSError`` error;
                * a failing worker is no longer known to the scheduler;
                * three retry rounds (counted cumulatively) make no progress.
            P2PIllegalStateError
                If a failing worker is unknown but the shuffle is not archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # A None result means success; OSError results are retried and any
                # other non-None result aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"Broadcast not making progress for {shuffle}. "
                                "Aborting. This is possibly due to overloaded "
                                "workers. Increasing config "
                                "`distributed.comm.timeouts.connect` timeout may "
                                "help."
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run.

            A request whose ``run_id`` is older or newer than the active shuffle's
            gets an error message back instead of an exception.
            """
            try:
                shuffle = self.active_shuffles[id]
                current = shuffle.run_id
                if current > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat ``data`` from worker ``ws``.

            Entries for shuffles that are not currently active are ignored.
            """
            # Build the active-ID set once rather than once per heartbeat entry
            active = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` and record ``worker``.

            An unknown shuffle or unknown worker is returned as an error message
            rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
            return {"status": "OK", "run_spec": ToPickle(spec)}
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add ``worker`` to shuffle ``id``'s participants and return its run spec.

            Raises ``KeyError`` if ``id`` is not an active shuffle.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            participants = shuffle_state.participating_workers
            participants.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33641', workers: 0, cores: 0, tasks: 0>
a = array([[0.77218519, 0.98403101, 0.73971916, ..., 0.64295602, 0.86854558,
        0.98601708],
       [0.09089845, 0.03...8 ,
        0.69222466],
       [0.63829716, 0.61898815, 0.94981538, ..., 0.25538011, 0.30019752,
        0.31170137]])
b = <Worker 'tcp://127.0.0.1:41547', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((10, 10, 10, 10, 10, 10, ...), (10, 10, 10, 10, 10, 10, ...))
expected_algorithm = 'p2p'

    @pytest.mark.parametrize(
        ["new", "expected_algorithm"],
        [
            # All-to-all rechunking defaults to P2P
            (((1,) * 100, (100,)), "p2p"),
            # Localized rechunking defaults to tasks
            (((50, 50), (2,) * 50), "tasks"),
            # Less local rechunking first defaults to tasks,
            (((25, 25, 25, 25), (4,) * 25), "tasks"),
            # then switches to p2p
            (((10,) * 10, (10,) * 10), "p2p"),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
        a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
        x = da.from_array(a, chunks=(100, 1))
        x2 = rechunk(x, chunks=new)
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        else:
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:239: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Shuffle identifier: a distinct ``str`` type for static checking
    ShuffleId = NewType("ShuffleId", str)
    # Multi-dimensional index tuple, e.g. ``(0, 3)``
    NDIndex: TypeAlias = tuple[int, ...]


    # Type parameters of ShuffleRun: output partition identifier and partition data
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic return type, used by ShuffleRun.offload
    _T = TypeVar("_T")
    
    
    # OK response carrying a shuffle run spec, either raw or wrapped in ToPickle
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up one run of a P2P shuffle on this worker.

            Parameters
            ----------
            id : ShuffleId
                Identifier of the shuffle this run belongs to.
            run_id : int
                Identifier of this run; also used as the object's hash.
            span_id : str | None
                Span under which p2p metrics are recorded (see ``_capture_metrics``).
            local_address : str
                Address of the worker hosting this run.
            directory : str
                Directory for the on-disk shard buffer; only used when ``disk`` is true.
            executor : ThreadPoolExecutor
                Executor used by ``offload``.
            rpc : Callable[[str], PooledRPCCall]
                Returns an RPC handle for a peer address (used by ``_send``).
            digest_metric : Callable[[Hashable, float], None]
                Sink that receives captured metrics.
            scheduler : PooledRPCCall
                RPC handle to the scheduler.
            memory_limiter_disk, memory_limiter_comms : ResourceLimiter
                Limiters passed to the disk and comm buffers respectively.
            disk : bool
                Use ``DiskShardsBuffer`` if true, else ``MemoryShardsBuffer``.
            loop : IOLoop
                Event loop, stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Message size cap and send concurrency come from dask config
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            # Set by close() once both buffers are closed; repeat callers wait on it
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy used by send(); delays are parsed with seconds as the default unit
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Debug representation with the run's identity and lifecycle flags."""
            cls_name = self.__class__.__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Short label of the form ``<Class><id[run_id]> on <address>``."""
            label = f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
            return f"{label} on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its integer ``run_id``."""
            run_id: int = self.run_id
            return run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.

            A leading ``"p2p-"`` on the first label element is stripped, since the
            metric is already namespaced under ``"p2p"``. An empty tuple label is
            passed through unchanged instead of raising ``IndexError``.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Guard against an empty tuple: indexing label[0] would raise
                # IndexError from inside the metrics callback.
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            # Offloading is only allowed when capturing for a background context
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Notify the scheduler barrier whether every ``run_ids`` entry matches ours.

            Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            mismatched = any(other != self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatched
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` (a list or a pickled blob) to the peer at ``address``."""
            self.raise_if_closed()
            peer = self.rpc(address)
            payload = to_serialize(shards)
            return await peer.shuffle_receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` through ``retry`` with the configured policy.

            When the mean shard size is below 64 KiB, the shards are pickled into one
            opaque ``bytes`` blob. This avoids shipping many small buffers individually
            over the tcp comms; ``receive`` on the peer unpickles it.
            Performance tests informing the size threshold:
            https://github.com/dask/distributed/pull/8318
            """
            small = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small else shards

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect buffer statistics for this run.

            Combines the disk buffer's and comm buffer's heartbeat dicts; the comm
            part also gets ``"read"`` set to ``total_recvd``. ``"start"`` is the run's
            start time.
            """
            comm = self._comm_buffer.heartbeat()
            comm["read"] = self.total_recvd
            disk = self._disk_buffer.heartbeat()
            return {"disk": disk, "comm": comm, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer (raises if the run is closed)."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keyed by underscore-joined indices.

            For example, index ``(1, 2)`` becomes key ``"1_2"``.
            """
            self.raise_if_closed()
            flattened: dict[str, Any] = {}
            for index, value in data.items():
                flattened["_".join(map(str, index))] = value
            await self._disk_buffer.write(flattened)
    
        def raise_if_closed(self) -> None:
            """Raise when closed: the recorded failure if any, else ShuffleClosedError."""
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark inputs as transferred, flush comms, and surface any comm error.

            A comm-buffer error is stored in ``_exception`` before being re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as comm_error:
                self._exception = comm_error
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P c806e7671339aecf80206effa5e88ca8 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_cull_p2p_rechunk_independent_partitions (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
assert 57 < (228 / 4)
 +  where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948f9dc5e0>\n 0. getitem-d1db8f0a6402f938465bc6034b01ed58\n)
 +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948f9dc5e0>\n 0. getitem-d1db8f0a6402f938465bc6034b01ed58\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
 +  and   228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948e19f490>\n 0. rechunk-p2p-620b9aabeabeb94b707695b0db4eb07a\n)
 +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948e19f490>\n 0. rechunk-p2p-620b9aabeabeb94b707695b0db4eb07a\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36421', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:45917', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:36767', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[[0.32911429, 0.67956124, 0.09958113, 0.31750314, 0.54180337,
         0.42813737, 0.07759511, 0.17287479, 0.19...0.91310963, 0.59520114, 0.99521652, 0.35626236,
         0.01563698, 0.81999609, 0.8769825 , 0.71105292, 0.33717714]]])
x = dask.array<array, shape=(10, 10, 10), dtype=float64, chunksize=(1, 5, 1), chunktype=numpy.ndarray>
new = (5, 1, -1)

    @gen_cluster(client=True)
    async def test_cull_p2p_rechunk_independent_partitions(c, s, *ws):
        a = np.random.default_rng().uniform(0, 1, 1000).reshape((10, 10, 10))
        x = da.from_array(a, chunks=(1, 5, 1))
        new = (5, 1, -1)
        rechunked = rechunk(x, chunks=new, method="p2p")
        (dsk,) = dask.optimize(rechunked)
        culled = rechunked[:5, :2]
        (dsk_culled,) = dask.optimize(culled)
    
        # The culled graph requires only 1/2 of the input tasks
        n_inputs = len(
            [1 for key in dsk.dask.get_all_dependencies() if key[0].startswith("array-")]
        )
        n_culled_inputs = len(
            [
                1
                for key in dsk_culled.dask.get_all_dependencies()
                if key[0].startswith("array-")
            ]
        )
        assert n_culled_inputs == n_inputs / 4
        # The culled graph should also have less than 1/4 the tasks
>       assert len(dsk_culled.dask) < len(dsk.dask) / 4
E       assert 57 < (228 / 4)
E        +  where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948f9dc5e0>\n 0. getitem-d1db8f0a6402f938465bc6034b01ed58\n)
E        +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948f9dc5e0>\n 0. getitem-d1db8f0a6402f938465bc6034b01ed58\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
E        +  and   228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948e19f490>\n 0. rechunk-p2p-620b9aabeabeb94b707695b0db4eb07a\n)
E        +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948e19f490>\n 0. rechunk-p2p-620b9aabeabeb94b707695b0db4eb07a\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask

distributed/shuffle/tests/test_rechunk.py:265: AssertionError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_expand (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P bbd47b3374e9ae0063a0808c2bcfb35f failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'bbd47b3374e9ae0063a0808c2bcfb35f'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…r
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
>           yield

distributed/shuffle/_core.py:523: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/shuffle/_rechunk.py:170: in rechunk_transfer
    return get_worker_plugin().add_partition(
distributed/shuffle/_worker_plugin.py:348: in add_partition
    shuffle_run = self.get_or_create_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33125', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:36817', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:44325', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.20961821, 0.03813439, 0.16031038, 0.83008987, 0.47724403,
        0.15157882, 0.65182507, 0.81244842, 0.2439..., 0.22958412, 0.27287624, 0.66399029, 0.51144192,
        0.08086217, 0.82881384, 0.60756522, 0.08527088, 0.51785269]])
x = dask.array<array, shape=(10, 10), dtype=float64, chunksize=(5, 5), chunktype=numpy.ndarray>
y = dask.array<rechunk-p2p, shape=(10, 10), dtype=float64, chunksize=(3, 3), chunktype=numpy.ndarray>

    @gen_cluster(client=True)
    async def test_rechunk_expand(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_expand
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(5, 5))
        y = x.rechunk(chunks=((3, 3, 3, 1), (3, 3, 3, 1)), method="p2p")
>       assert np.all(await c.compute(y) == a)

distributed/shuffle/tests/test_rechunk.py:377: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P bbd47b3374e9ae0063a0808c2bcfb35f failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 6 runs failed: test_rechunk_expand2 (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P b07060492e82fee76856c8bc3b212713 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'b07060492e82fee76856c8bc3b212713'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…rs(id: ShuffleId) -> Iterator[None]:
        try:
>           yield

distributed/shuffle/_core.py:523: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/shuffle/_rechunk.py:170: in rechunk_transfer
    return get_worker_plugin().add_partition(
distributed/shuffle/_worker_plugin.py:348: in add_partition
    shuffle_run = self.get_or_create_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:44097', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:38371', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:38175', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = 3, b = 2
orig = array([[0.6017826 , 0.45461851, 0.95201606],
       [0.20078242, 0.9710892 , 0.78108917],
       [0.66423832, 0.27980603, 0.07144657]])

    @gen_cluster(client=True)
    async def test_rechunk_expand2(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_expand2
        """
        (a, b) = (3, 2)
        orig = np.random.default_rng().uniform(0, 1, a**b).reshape((a,) * b)
        for off, off2 in product(range(1, a - 1), range(1, a - 1)):
            old = ((a - off, off),) * b
            x = da.from_array(orig, chunks=old)
            new = ((a - off2, off2),) * b
            assert np.all(await c.compute(x.rechunk(chunks=new, method="p2p")) == orig)
            if a - off - off2 > 0:
                new = ((off, a - off2 - off, off2),) * b
>               y = await c.compute(x.rechunk(chunks=new, method="p2p"))

distributed/shuffle/tests/test_rechunk.py:396: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P b07060492e82fee76856c8bc3b212713 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_unknown_from_pandas (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d30e37775e332baa036f77fc7f4c89c6 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'd30e37775e332baa036f77fc7f4c89c6'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…n.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35243', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:36087', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42641', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
pd = <module 'pandas' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/pandas/__init__.py'>
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
arr = array([[-2.65340026e-02, -6.62908051e-01, -3.74405314e-01,
         1.62517681e+00,  4.22948452e-01,  1.14427358e+00,
... 1.24679262e+00,  5.95299989e-01,
         3.79757690e-01, -4.28536040e-01, -8.06355440e-01,
         1.46740310e+00]])

    @gen_cluster(client=True)
    async def test_rechunk_unknown_from_pandas(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_unknown_from_pandas
        """
        pd = pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
    
        arr = np.random.default_rng().standard_normal((50, 10))
        x = dd.from_pandas(pd.DataFrame(arr), 2).values
        result = x.rechunk((None, (5, 5)), method="p2p")
        assert np.isnan(x.chunks[0]).all()
        assert np.isnan(result.chunks[0]).all()
        assert result.chunks[1] == (5, 5)
        expected = da.from_array(arr, chunks=((25, 25), (10,))).rechunk(
            (None, (5, 5)), method="p2p"
        )
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:706: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Labels may be scalars or tuples; normalize to a tuple.
                if not isinstance(label, tuple):
                    label = (label,)
                # Drop the redundant "p2p-" prefix from the first label component: the
                # metric is already namespaced under "p2p". The emptiness check keeps an
                # empty-tuple label from raising IndexError.
                if label and isinstance(label[0], str):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            # allow_offload is enabled only for the "background-*" capture sites.
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Notify the scheduler that this run reached the shuffle barrier.

            ``run_ids`` are the run ids handed to the barrier (presumably one per
            input task). The request is flagged as inconsistent when any of them
            differs from this run's id. Returns this run's id.
            """
            self.raise_if_closed()
            mismatched = any(rid != self.run_id for rid in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatched
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` to the shuffle run on ``address`` with one RPC call."""
            self.raise_if_closed()
            remote = self.rpc(address)
            payload = to_serialize(shards)
            return await remote.shuffle_receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address``, retrying failed attempts.

            Retry count and delay bounds come from the ``distributed.p2p.comm.retry``
            settings read in ``__init__``.
            """
            # Don't send buffers individually over the tcp comms when shards are small.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            small = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small else shards

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor and await it.

            The elapsed time is metered under the "offload" label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Summarize this run's progress for the scheduler heartbeat.

            The comm buffer's heartbeat dict is extended in place with the total
            bytes received under "read".
            """
            comm = self._comm_buffer.heartbeat()
            comm["read"] = self.total_recvd
            disk = self._disk_buffer.heartbeat()
            return dict(disk=disk, comm=comm, start=self.start_time)
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed.

            A recorded failure (``self._exception``) takes precedence; otherwise a
            ShuffleClosedError is raised.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark all inputs as transferred and flush outgoing shards.

            If the comm buffer reports an exception, it is stored on the run as
            ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            comm = self._comm_buffer
            try:
                comm.raise_on_exception()
            except Exception as failure:
                # Remember the failure so later closed-state checks re-raise it
                self._exception = failure
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer that holds received shards."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close this run's comm and disk buffers.

            Repeated calls wait until the first close has finished. The closed event
            is set even if closing a buffer raises, so concurrent or later callers
            never block forever. The disk buffer is closed even if closing the comm
            buffer fails. The first exception propagates to the caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    # Release the disk buffer even if the comm buffer failed to close
                    await self._disk_buffer.close()
            finally:
                # Always wake callers waiting in the branch above
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as this run's failure unless it is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by another worker.

            ``data`` is either a list of shards or the pickled blob produced by
            ``send``. A P2PConsistencyError is reported as an error message; other
            exceptions propagate.
            """
            try:
                if not isinstance(data, bytes):
                    shards = data
                else:
                    # Unpack opaque blob. See send()
                    shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition ``partition_id`` and queue its shards for sending.

            Returns this run's id. Raises RuntimeError if inputs have already been
            marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Metrics are logged both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    sharded = self._shard_partition(data, partition_id)
                loop = self._loop
                sync(loop, self._write_to_comm, sharded)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` for task ``key``.

            First ensures the task runs on the partition's assigned worker, then
            flushes the disk buffer and assembles the partition. Raises
            RuntimeError if inputs have not yet been marked as transferred.
            """
            self.raise_if_closed()
            loop = self._loop
            sync(loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"), context_meter.meter(
                "p2p-get-output-noncpu"
            ), context_meter.meter("p2p-get-output-cpu", func=thread_time):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Load the shards stored at ``path`` on disk."""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Turn a serialized shard ``buffer`` back into shard objects."""
    
        def validate_data(self, data: Any) -> None:
            """Hook to check payload ``data`` before it is sharded.

            This default accepts everything; subclasses may override it.
            """
            return None
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the P2P shuffle plugin of the worker executing the current task.

        Raises RuntimeError when not running on a worker, or when the worker has
        no plugin registered under "shuffle".
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        plugins = worker.plugins
        try:
            return plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    # Prefix of every shuffle barrier task key; the shuffle id follows it.
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task of shuffle ``shuffle_id``."""
        return "".join((_BARRIER_PREFIX, shuffle_id))


    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier task key.

        Returns None when ``key`` is not a string starting with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle operation."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle.

        ``run_id`` is not passed by callers. It is drawn from a counter created
        once when the class is defined and starting at 1, so each new spec gets
        the next integer.
        """

        run_id: int = field(init=False, default_factory=itertools.count(1).__next__)
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle this run belongs to."""
            shuffle_spec = self.spec
            return shuffle_spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Immutable, shuffle-type-specific description of a P2P shuffle.

        Concrete subclasses define the output partitions, how partitions are
        mapped to workers and how a run is instantiated on a worker.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions of this shuffle."""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose, among ``workers``, the worker responsible for ``partition``."""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create scheduler-side state for a new run of this shuffle.

            The participating workers start out as the distinct workers that
            appear in ``worker_for``.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=run_spec,
                participating_workers=participants,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Instantiate run ``run_id`` of this shuffle on the worker owning ``plugin``."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one run of a shuffle.

        With ``eq=False`` equality is object identity; the hash is derived from
        the run id.
        """

        run_spec: ShuffleRunSpec
        # Addresses of the workers taking part in this run
        participating_workers: set[str]
        # Set once the run is archived (presumably to the archiving stimulus id)
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Id of this run."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """Whether this run has been archived."""
            return not (self._archived_by is None)

        def __str__(self) -> str:
            return "{}<{}[{}]>".format(self.__class__.__name__, self.id, self.run_id)

        def __hash__(self) -> int:
            run_id = self.run_id
            return hash(run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised during the transfer phase of shuffle ``id``.

        - ShuffleClosedError is replaced by Reschedule.
        - P2PConsistencyError and P2POutOfDiskError propagate unchanged.
        - Any other exception is wrapped in a RuntimeError naming the shuffle,
          chained to the original exception.
        """
        try:
            yield
        except ShuffleClosedError:
            # A closed run is not reported as a failure; the task is rescheduled instead
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P d30e37775e332baa036f77fc7f4c89c6 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x0-chunks0] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger; the plugin below uses it for shuffle warnings.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ShuffleWorkerPlugin with the scheduler as "shuffle"."""
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a fresh set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of shuffle ``id`` for run ``run_id``.

            If the inputs were reported inconsistent, the shuffle is restarted
            instead. Otherwise "shuffle_inputs_done" is broadcast to every
            participating worker. Workers whose broadcast failed with an OSError
            are retried; any other error aborts the barrier.

            Raises
            ------
            ValueError
                If ``run_id`` is not the active run of shuffle ``id``.
            P2PIllegalStateError
                If a participating worker is unknown to the scheduler while the
                shuffle has not been archived.
            RuntimeError
                On an unexpected broadcast error, if a participating worker left,
                or after three retry rounds that made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # A None result counts as success; only OSError failures are retried.
                # This list is the retry set; no need to recompute it afterwards.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if not isinstance(r, OSError):
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                    workers.append(w)
                if not workers:
                    break
                logger.warning(
                    "Failure during broadcast of %s, retrying.",
                    shuffle.id,
                )
                if any(w not in self.scheduler.workers for w in workers):
                    if not shuffle.archived:
                        # If the shuffle is not yet archived, this could mean that the barrier task fails
                        # before the P2P restarting mechanism can kick in.
                        raise P2PIllegalStateError(
                            "Expected shuffle to be archived if participating worker is not known by scheduler"
                        )
                    raise RuntimeError(
                        f"Worker {workers} left during shuffle {shuffle}"
                    )
                await asyncio.sleep(0.1)
                if len(workers) == before:
                    no_progress += 1
                    if no_progress >= 3:
                        # Single-line message (the former triple-quoted string embedded
                        # newlines and source indentation in the error text).
                        raise RuntimeError(
                            f"Broadcast not making progress for {shuffle}. "
                            "Aborting. This is possibly due to overloaded workers. "
                            "Increasing config `distributed.comm.timeouts.connect` "
                            "timeout may help."
                        )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run of ``id``.

            Returns an OK message, or an error message when the request refers to
            an older (stale) or newer (invalid) run than the active one.
            """
            try:
                current = self.active_shuffles[id]
                if current.run_id != run_id:
                    reason = "stale" if current.run_id > run_id else "invalid"
                    raise P2PConsistencyError(
                        f"Request {reason}, expected {run_id=} for {current}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Record heartbeat ``data`` from worker ``ws`` for each active shuffle.

            ``data`` maps shuffle ids to payload dicts. Payloads for shuffles that
            are not active are ignored.
            """
            # Build the active-id set once instead of once per reported shuffle.
            active = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the active run spec of shuffle ``id`` on behalf of ``worker``.

            An unknown shuffle or an inconsistent request is returned as an error
            message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK", "run_spec": ToPickle(spec)}
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ShuffleSpec carried by the barrier task of ``shuffle_id``."""
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_task.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run of ``shuffle_id``, triggered by task ``key`` on ``worker``.

            Validates the barrier and the triggering task, and computes the
            ``worker_for`` assignment. It then registers the new run as the active
            one and returns its run spec.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            assignments = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = spec.create_new_run(worker_for=assignments, span_id=span_id)
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger; the plugin below uses it for shuffle warnings.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ShuffleWorkerPlugin with the scheduler as "shuffle"."""
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a fresh set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of shuffle ``id`` for run ``run_id``.

            If the inputs were reported inconsistent, the shuffle is restarted
            instead. Otherwise "shuffle_inputs_done" is broadcast to every
            participating worker. Workers whose broadcast failed with an OSError
            are retried; any other error aborts the barrier.

            Raises
            ------
            ValueError
                If ``run_id`` is not the active run of shuffle ``id``.
            P2PIllegalStateError
                If a participating worker is unknown to the scheduler while the
                shuffle has not been archived.
            RuntimeError
                On an unexpected broadcast error, if a participating worker left,
                or after three retry rounds that made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # A None result counts as success; only OSError failures are retried.
                # This list is the retry set; no need to recompute it afterwards.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if not isinstance(r, OSError):
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                    workers.append(w)
                if not workers:
                    break
                logger.warning(
                    "Failure during broadcast of %s, retrying.",
                    shuffle.id,
                )
                if any(w not in self.scheduler.workers for w in workers):
                    if not shuffle.archived:
                        # If the shuffle is not yet archived, this could mean that the barrier task fails
                        # before the P2P restarting mechanism can kick in.
                        raise P2PIllegalStateError(
                            "Expected shuffle to be archived if participating worker is not known by scheduler"
                        )
                    raise RuntimeError(
                        f"Worker {workers} left during shuffle {shuffle}"
                    )
                await asyncio.sleep(0.1)
                if len(workers) == before:
                    no_progress += 1
                    if no_progress >= 3:
                        # Single-line message (the former triple-quoted string embedded
                        # newlines and source indentation in the error text).
                        raise RuntimeError(
                            f"Broadcast not making progress for {shuffle}. "
                            "Aborting. This is possibly due to overloaded workers. "
                            "Increasing config `distributed.comm.timeouts.connect` "
                            "timeout may help."
                        )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run of ``id``.

            Returns an OK message, or an error message when the request refers to
            an older (stale) or newer (invalid) run than the active one.
            """
            try:
                current = self.active_shuffles[id]
                if current.run_id != run_id:
                    reason = "stale" if current.run_id > run_id else "invalid"
                    raise P2PConsistencyError(
                        f"Request {reason}, expected {run_id=} for {current}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '997a532b09b4edeae24e25db52e8ece8'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…22: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39695', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:35371', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41665', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Initialise a shuffle run on this worker.

            Stores the run's identity and collaborators, and builds the local shard
            buffer: ``DiskShardsBuffer`` when ``disk`` is true, else
            ``MemoryShardsBuffer``. Also builds the outgoing ``CommShardsBuffer``
            and reads the send-retry policy from the ``distributed.p2p.comm``
            config section.
            """
            # Identity of this run and the services it talks to.
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Build the buffers with the current metric callbacks cleared. This keeps
            # metrics from their background work from being attributed to the dask
            # task that happens to create this run.
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if not disk:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
                    else:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )

                with self._capture_metrics("background-comms"):
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=parse_bytes(
                            dask.config.get("distributed.p2p.comm.message-bytes-limit")
                        ),
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=dask.config.get(
                            "distributed.p2p.comm.concurrency"
                        ),
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Progress and lifecycle bookkeeping.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy used by send(); unitless delays are parsed as seconds.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Debug representation: identity of the run plus its lifecycle flags."""
            cls_name = type(self).__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Short human-readable form: ``<Class><id[run_id]> on <address>``."""
            cls_name = type(self).__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its run ID."""
            run_id = self.run_id
            return run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.

            A leading ``"p2p-"`` on the first label component is stripped, so
            ``"p2p-foo"`` is recorded as ``"foo"``. An empty tuple label is passed
            through as-is instead of failing with ``IndexError``.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise scalar labels to 1-tuples so they can be splatted below.
                if not isinstance(label, tuple):
                    label = (label,)
                # Guard against an empty tuple before looking at label[0].
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            # allow_offload is enabled only for the "background*" variants of `where`.
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report this run's barrier to the scheduler and return ``self.run_id``.

            The barrier is flagged as inconsistent unless every entry of
            ``run_ids`` equals this run's ID. Presumably these are the run IDs
            returned by the input tasks; confirm with callers.
            """
            self.raise_if_closed()
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            mismatch = any(rid != self.run_id for rid in run_ids)
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatch
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Ship ``shards`` (a list, or one pickled blob) to the worker at ``address``."""
            self.raise_if_closed()
            worker_rpc = self.rpc(address)
            return await worker_rpc.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the run's retry policy.

            When the mean shard size is below 64 KiB, the shards are pickled into a
            single opaque bytes blob instead of being sent as individual buffers.
            ``receive`` unpickles the blob on the other side. The threshold comes
            from the performance tests in
            https://github.com/dask/distributed/pull/8318
            """
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # A fresh coroutine per attempt, so retry() can call it repeatedly.
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on the run's executor, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Return progress information for this run.

            Keys: "disk" holds the local buffer's heartbeat; "comm" holds the comm
            buffer's heartbeat with ``total_recvd`` added under "read"; "start"
            holds the run's start timestamp.
            """
            comm = self._comm_buffer.heartbeat()
            comm["read"] = self.total_recvd
            return dict(
                disk=self._disk_buffer.heartbeat(),
                comm=comm,
                start=self.start_time,
            )
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Queue sharded data in the comm buffer for sending.

            Keys are presumably output-worker addresses (as produced by
            ``_shard_partition``); confirm against subclasses.
            """
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write shards to the local buffer, keyed by their "_"-joined index."""
            self.raise_if_closed()
            by_path = {"_".join(map(str, index)): shard for index, shard in data.items()}
            await self._disk_buffer.write(by_path)
    
        def raise_if_closed(self) -> None:
            """Raise if this run is closed.

            Re-raises the recorded failure when there is one. Otherwise raises
            ``ShuffleClosedError``. Returns silently while the run is open.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark all inputs as transferred and flush the outgoing comm buffer.

            Any error reported by the comm buffer is stored in ``_exception`` and
            re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as exc:
                # Keep the failure so raise_if_closed() can re-raise it once closed.
                self._exception = exc
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the outgoing comm buffer after checking the run is still open."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the local shard buffer after checking the run is still open."""
            self.raise_if_closed()
            buffer = self._disk_buffer
            await buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and local shard buffers, once.

            A repeated or concurrent call waits for the first call to finish and
            then returns. The disk buffer is closed even if closing the comm buffer
            raises, and ``_closed_event`` is always set. Without this, a failure
            while closing would leave every waiting caller blocked forever.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                await self._comm_buffer.close()
            finally:
                try:
                    await self._disk_buffer.close()
                finally:
                    self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as this run's failure unless the run is closed.

            This does not close the run. ``raise_if_closed`` re-raises the recorded
            exception once the run is closed.
            """
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored under the "_"-joined form of index ``id``."""
            self.raise_if_closed()
            path = "_".join(map(str, id))
            return self._disk_buffer.read(path)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by ``send`` on another worker.

            ``data`` is either the shard list itself or the pickled blob that
            ``send`` produces for small shards. A ``P2PConsistencyError`` is
            returned as an error message; anything else propagates.
            """
            try:
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is produced on its assigned worker.

            If this worker is not the assigned one, the scheduler is asked to
            restrict task ``key`` to the assigned worker and ``Reschedule`` is
            raised. An error reply from the scheduler becomes a ``RuntimeError``.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            reply = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if reply["status"] == "error":
                raise RuntimeError(reply["message"])
            assert reply["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker that owns output partition ``i``."""
            ...
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Handle incoming shards for output partitions of this run (subclass hook)."""
            ...
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition ``data`` and queue the shards for sending.

            Returns this run's ID. Raises ``RuntimeError`` if the inputs have
            already been marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-shard-partition-noncpu"):
                    with context_meter.meter("p2p-shard-partition-cpu", func=thread_time):
                        shards = self._shard_partition(data, partition_id)
                # Block this (task) thread until the event loop has queued the shards.
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Split input partition ``data`` into shards grouped by assigned output worker."""
            ...
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` for task ``key``.

            First ensures the task runs on the partition's assigned worker (it may
            be rescheduled). Then requires the barrier to have completed, flushes
            received shards, and finally assembles the partition.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Assemble and return output partition ``partition_id`` (subclass hook)."""
            ...
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards back from ``path`` on disk.

            Returns a pair whose second item is an int; presumably a size. Check
            the subclasses.
            """
            ...
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Turn a serialized ``buffer`` back into shards (subclass hook)."""
            ...
    
        def validate_data(self, data: Any) -> None:
            """Hook to reject invalid payloads before sharding; the base accepts anything."""
            return None
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the "shuffle" plugin of the worker running the current task.

        Raises
        ------
        RuntimeError
            If not running on a distributed Worker, or if that worker has no
            plugin registered under "shuffle".
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
        return plugin  # type: ignore
    
    
    # Key prefix of shuffle barrier tasks. barrier_key() prepends it to a shuffle ID
    # and id_from_key() strips it again.
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the task key of the barrier task for shuffle ``shuffle_id``."""
        prefix = _BARRIER_PREFIX
        return prefix + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Return the shuffle ID encoded in a barrier task key.

        Returns None if ``key`` is not a string starting with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle: DataFrame shuffles and array rechunking."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        ``run_id`` is not an ``__init__`` argument. It is drawn from a single
        ``itertools.count(1)`` created when this class is defined, so run IDs start
        at 1 and increase across all runs created in this process.
        """

        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Output partition -> address of the worker assigned to it (values are
        # worker addresses; see ShuffleSpec.create_new_run).
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle this run belongs to (taken from ``spec``)."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Immutable, type-specific description of a shuffle.

        Subclasses define the output partitions, how partitions are mapped to
        workers, and how a run is instantiated on a worker.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Yield the IDs of all output partitions."""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose the worker out of ``workers`` that owns ``partition``."""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Build scheduler state for a fresh run of this shuffle.

            Every worker appearing in ``worker_for`` starts out as a participant.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=run_spec,
                participating_workers=participants,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Instantiate the worker-side run object for run ``run_id``."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one run of a shuffle.

        ``eq=False`` keeps identity-based equality, while hashing uses the run ID.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle ID, taken from the run spec."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Run ID, taken from the run spec."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """Whether this run has been archived (``_archived_by`` is set)."""
            archiver = self._archived_by
            return archiver is not None

        def __str__(self) -> str:
            cls_name = type(self).__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            run_id = self.run_id
            return hash(run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised inside the ``with`` block during the transfer phase.

        - ``ShuffleClosedError`` is replaced by ``Reschedule``.
        - ``P2PConsistencyError`` and ``P2POutOfDiskError`` propagate unchanged.
        - Any other ``Exception`` is wrapped in a ``RuntimeError`` naming shuffle
          ``id``, chained to the original.
        """
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x1-chunks1] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger for the scheduler-side shuffle plugin.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Create the scheduler-side shuffle plugin and attach it to ``scheduler``.

            Registers the ``shuffle_*`` RPC handlers and adds the plugin under the
            name "shuffle". All internal state is initialised *before* ``self`` is
            handed to the scheduler, so nothing the scheduler does during
            registration can observe a partially constructed plugin.
            """
            self.scheduler = scheduler
            # Heartbeat payloads per shuffle ID, then per worker address.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            # Current run state per shuffle ID.
            self.active_shuffles = {}
            # Every run state created per shuffle ID.
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.scheduler.add_plugin(self, name="shuffle")
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ShuffleWorkerPlugin with the scheduler as "shuffle"."""
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                pickled_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of shuffle ``id`` for run ``run_id``.

            If ``consistent`` is false, the shuffle is restarted instead. Otherwise
            ``shuffle_inputs_done`` is broadcast to every participating worker.
            Workers whose reply is an ``OSError`` are retried after a short sleep.

            Raises
            ------
            ValueError
                If ``run_id`` is not the run ID of the active shuffle.
            RuntimeError
                If a worker replies with an error other than ``OSError``, if a
                failing worker has left an archived shuffle, or after three retry
                rounds (counted cumulatively) that did not shrink the set of
                failing workers.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler but the shuffle has
                not been archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect workers to retry: None means success, an OSError is
                # retried, and any other error aborts the barrier. This list is
                # final for this round (it used to be recomputed redundantly).
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of run ``run_id``.

            A run ID that does not match the active run of shuffle ``id`` is
            reported as an error message instead of being applied.
            """
            try:
                shuffle = self.active_shuffles[id]
                current = shuffle.run_id
                if current > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the per-shuffle heartbeat payloads reported by worker ``ws``.

            ``data`` maps shuffle IDs to heartbeat dicts. Payloads for shuffles
            that are not active are ignored. Membership is checked directly against
            ``active_shuffles`` rather than rebuilding the full set of IDs via
            ``shuffle_ids()`` for every entry.
            """
            for shuffle_id, d in data.items():
                if shuffle_id in self.active_shuffles:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the pickled run spec of active shuffle ``id`` for ``worker``.

            Failures are returned as an error message, not raised. An unknown
            shuffle (``KeyError``) is first converted to a ``P2PConsistencyError``.
            """
            try:
                try:
                    message: RunSpecMessage = {
                        "status": "OK",
                        "run_spec": ToPickle(self._get(id, worker)),
                    }
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
            return message
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add ``worker`` to the participants of shuffle ``id`` and return its run spec.

            Raises ``KeyError`` if the shuffle is not active, and
            ``P2PConsistencyError`` if the scheduler does not know ``worker``.
            """
            if worker in self.scheduler.workers:
                shuffle_state = self.active_shuffles[id]
                shuffle_state.participating_workers.add(worker)
                return shuffle_state.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ShuffleSpec carried by the barrier task of ``shuffle_id``."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            barrier_task = barrier_ts.run_spec
            assert isinstance(barrier_task, P2PBarrierTask)
            return barrier_task.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run of shuffle ``shuffle_id``, requested by task ``key`` on ``worker``.

            Runs the ``_raise_if_*`` guards first, then reads the spec from the
            barrier task and computes ``worker_for``. The new state is registered as
            the active run (also recorded in ``_shuffles``) with ``worker`` as a
            participant, and its run spec is returned.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            # The run is attached to the span of the requesting task's group.
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger for the scheduler-side shuffle plugin.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Create the scheduler-side shuffle plugin and attach it to ``scheduler``.

            Registers the ``shuffle_*`` RPC handlers and adds the plugin under the
            name "shuffle". All internal state is initialised *before* ``self`` is
            handed to the scheduler, so nothing the scheduler does during
            registration can observe a partially constructed plugin.
            """
            self.scheduler = scheduler
            # Heartbeat payloads per shuffle ID, then per worker address.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            # Current run state per shuffle ID.
            self.active_shuffles = {}
            # Every run state created per shuffle ID.
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.scheduler.add_plugin(self, name="shuffle")
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ShuffleWorkerPlugin with the scheduler as "shuffle"."""
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                pickled_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of shuffle ``id`` for run ``run_id``.

            If ``consistent`` is false, the shuffle is restarted instead. Otherwise
            ``shuffle_inputs_done`` is broadcast to every participating worker.
            Workers whose reply is an ``OSError`` are retried after a short sleep.

            Raises
            ------
            ValueError
                If ``run_id`` is not the run ID of the active shuffle.
            RuntimeError
                If a worker replies with an error other than ``OSError``, if a
                failing worker has left an archived shuffle, or after three retry
                rounds (counted cumulatively) that did not shrink the set of
                failing workers.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler but the shuffle has
                not been archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect workers to retry: None means success, an OSError is
                # retried, and any other error aborts the barrier. This list is
                # final for this round (it used to be recomputed redundantly).
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '997a532b09b4edeae24e25db52e8ece8'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…y:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40587', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:38455', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42433', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x2-chunks2] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '997a532b09b4edeae24e25db52e8ece8'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…n _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40857', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:44041', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:45303', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """One run of a P2P shuffle as executed on a single worker.

        The run shards input partitions (``add_partition``) and sends the shards to
        peer workers through a comm buffer. It receives shards from peers into a
        disk- or memory-backed buffer (``receive``). Finally it assembles output
        partitions (``get_output_partition``).

        Subclasses supply the data-specific parts: ``_shard_partition``,
        ``_receive``, ``_get_output_partition``, ``_get_assigned_worker``,
        ``read`` and ``deserialize``.

        Note that this class derives only from ``Generic`` (not ``abc.ABC``), so
        the ``@abc.abstractmethod`` markers below are not enforced when a
        subclass is instantiated.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        # Set by ``inputs_done``. Afterwards no more partitions may be added and
        # output partitions may be requested.
        transferred: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the run and its receive and send buffers.

            Parameters
            ----------
            id:
                Identifier of the shuffle.
            run_id:
                Identifier of this run. It is sent with every RPC and checked in
                ``barrier``.
            span_id:
                Span id used as the second element of every captured metric name.
            local_address:
                Address of this worker. It is compared with the assigned worker
                of an output partition in ``_ensure_output_worker``.
            directory:
                Directory used by the ``DiskShardsBuffer`` when *disk* is true.
            executor:
                Thread pool used by ``offload``.
            rpc:
                Factory returning an RPC handle for a peer worker address.
            digest_metric:
                Callback receiving ``(metric name, value)`` pairs.
            scheduler:
                RPC handle to the scheduler.
            memory_limiter_disk, memory_limiter_comms:
                Resource limiters handed to the receive and send buffers.
            disk:
                Buffer received shards on disk if true, otherwise in memory.
            loop:
                Event loop that the synchronous entry points (``add_partition``,
                ``get_output_partition``) ``sync`` onto.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped. A
            non-tuple label is treated as a one-element tuple, and an empty tuple
            label is accepted as well.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # An empty tuple label has no first element; indexing it directly
                # would raise IndexError.
                head = label[0] if label else None
                if isinstance(head, str) and head.startswith("p2p-"):
                    label = (head.removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report this worker's arrival at the shuffle barrier to the scheduler.

            *run_ids* is consistent when every entry equals ``self.run_id``.
            Returns ``self.run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send *shards* (a list or a pickled blob) to the worker at *address* once."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to the worker at *address*, retrying on failure.

            Shards whose mean size (as computed by ``_mean_shard_size``, presumably
            in bytes) is below 65536 are pickled into a single bytes blob first.
            Retries use ``RETRY_COUNT`` and the ``RETRY_DELAY_MIN`` /
            ``RETRY_DELAY_MAX`` settings.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in ``self.executor``, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return buffer statistics plus the run's start time.

            The comm entry is augmented with ``"read"`` set to ``total_recvd``.
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Queue *data* (keyed by destination worker) in the send buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write *data* to the receive buffer, keying each entry by its "_"-joined index."""
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if the run is closed.

            The stored failure is re-raised if there is one; otherwise
            ``ShuffleClosedError`` is raised.
            """
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark all inputs as added and flush the send buffer.

            Any exception recorded by the send buffer is stored on the run and
            re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            """Flush the send buffer."""
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the receive buffer."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close both buffers.

            A repeated call waits until the first close has finished and then
            returns.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Record *exception* as the failure of this run unless it is already closed.

            The run is not closed here. ``raise_if_closed`` only re-raises the
            stored exception once ``closed`` is set.
            """
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the buffered shards stored under the "_"-joined index *id*."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Receive shards sent by a peer's ``send``.

            *data* is either a list of ``(partition id, shard)`` pairs or a pickled
            blob of such a list. ``P2PConsistencyError`` is turned into an error
            message; other exceptions propagate.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): unpickles bytes received from a peer worker;
                    # only safe within a trusted cluster.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Reschedule task *key* if output partition *i* belongs to another worker.

            The scheduler is first asked to restrict *key* to the assigned worker.
            Raises ``RuntimeError`` if the scheduler reports an error, and
            ``Reschedule`` after a successful restriction.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition *data* and queue the shards for sending.

            Returns ``self.run_id``. Raises ``RuntimeError`` if inputs have
            already been marked done.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Assemble and return output partition *partition_id* for task *key*.

            Raises ``Reschedule`` (via ``_ensure_output_worker``) if the partition
            belongs to another worker. Raises ``RuntimeError`` if called before
            inputs were marked done.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling (no-op by default)."""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the P2P shuffle plugin of the worker running the current task.

        Raises
        ------
        RuntimeError
            If ``get_worker`` raises ``ValueError`` (the caller is not running on
            a Worker), or if the worker has no plugin registered as ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            current_worker = get_worker()
        except ValueError as exc:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from exc

        try:
            plugin = current_worker.plugins["shuffle"]
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {current_worker.address} does not have a P2P shuffle plugin."
            ) from exc
        return plugin  # type: ignore
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the task key of the barrier task of shuffle *shuffle_id*."""
        key = _BARRIER_PREFIX + shuffle_id
        return key


    def id_from_key(key: Key) -> ShuffleId | None:
        """Recover the shuffle id from a barrier task key (inverse of ``barrier_key``).

        Returns ``None`` unless *key* is a string starting with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kind of P2P shuffle; member values are string identifiers of each kind."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable specification of one run of a shuffle.

        ``run_id`` is not an ``__init__`` argument. Each newly created instance
        takes the next value of an ``itertools.count(1)`` counter that is created
        once, when this class is defined.
        """

        # Excluded from __init__; drawn from the shared counter starting at 1.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle this run belongs to.
        spec: ShuffleSpec
        # Partition id -> address of the worker assigned to it (presumably output
        # partitions; its values become the initial participating workers).
        worker_for: dict[_T_partition_id, str]
        # Span id used to label the run's metrics; may be None.
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            # Delegates to the id of the underlying ShuffleSpec.
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a P2P shuffle.

        Concrete subclasses define the output partitions, how a partition is
        mapped onto a worker, and how a run is instantiated on a worker.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Yield the ids of all output partitions of this shuffle."""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose one of *workers* to own *partition*."""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Build the scheduler-side state of a fresh run of this shuffle.

            The participating workers start out as the distinct values of
            *worker_for*.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=run_spec, participating_workers=participants
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Instantiate the worker-side ``ShuffleRun`` for run *run_id*."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping of one run of a shuffle.

        ``eq=False`` keeps identity-based equality, while ``__hash__`` hashes the
        run id so that instances can be kept in sets.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle this run belongs to."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Id of this run."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """True once ``_archived_by`` has been set."""
            archived_by = self._archived_by
            return archived_by is not None

        def __str__(self) -> str:
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            run_id = self.run_id
            return hash(run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised inside the block during the transfer phase of *id*.

        - ``ShuffleClosedError`` is replaced by ``Reschedule``.
        - ``P2PConsistencyError`` and ``P2POutOfDiskError`` propagate unchanged.
        - Any other ``Exception`` is wrapped in a ``RuntimeError`` chained to it.
        """
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x3-chunks3] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger used by the shuffle scheduler plugin below.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers and this plugin on *scheduler*."""
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # shuffle id -> worker address -> latest heartbeat data of that worker
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

            The registration goes through ``self.scheduler``; the *scheduler*
            argument itself is not used.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` RPC for run *run_id* of shuffle *id*.

            If the workers reported inconsistent run ids (``consistent=False``) the
            shuffle is restarted. Otherwise ``shuffle_inputs_done`` is broadcast to
            all participating workers. A ``None`` result counts as success, and
            workers that failed with ``OSError`` are retried.

            Raises
            ------
            ValueError
                If *run_id* is not the current run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the shuffle
                has not been archived.
            RuntimeError
                If a worker returns a non-``OSError`` error, if a failing worker
                left the cluster after the shuffle was archived, or after three
                rounds in which no failing worker recovered.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers that failed with a retryable OSError; any
                # other error aborts the barrier. The resulting list is final, so
                # it is not recomputed afterwards.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Count rounds in which no failing worker recovered.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task *key* to *worker* on behalf of run *run_id* of shuffle *id*.

            Returns an OK message, or an error message when *run_id* is older or
            newer than the active run of the shuffle.
            """
            try:
                active = self.active_shuffles[id]
                if active.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {active}"
                    )
                if active.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {active}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data reported by worker *ws*.

            Entries whose shuffle id is not currently active are ignored.
            """
            # Build the set of active ids once, not once per entry in *data*.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Handle the ``shuffle_get`` RPC.

            Returns the pickled run spec of the active run of *id* (registering
            *worker* as a participant), or an error message if there is no active
            shuffle with that id or the request is otherwise inconsistent.
            """
            try:
                try:
                    return {"status": "OK", "run_spec": ToPickle(self._get(id, worker))}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of the active run of *id* and add *worker* to it.

            Raises ``P2PConsistencyError`` if the scheduler does not know *worker*,
            and ``KeyError`` if no shuffle with *id* is active.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            active_state = self.active_shuffles[id]
            active_state.participating_workers.add(worker)
            return active_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` held by the barrier task of *shuffle_id*."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_ts.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create and activate a new run of *shuffle_id* requested by task *key*.

            *worker*, the worker executing *key*, is added to the participating
            workers of the new run. Returns the new run's spec.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            task_span_id = self.scheduler.tasks[key].group.span_id
            new_state = shuffle_spec.create_new_run(
                worker_for=assignment, span_id=task_span_id
            )
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger used by the shuffle scheduler plugin below.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers and this plugin on *scheduler*."""
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # shuffle id -> worker address -> latest heartbeat data of that worker
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

            The registration goes through ``self.scheduler``; the *scheduler*
            argument itself is not used.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` RPC for run *run_id* of shuffle *id*.

            If the workers reported inconsistent run ids (``consistent=False``) the
            shuffle is restarted. Otherwise ``shuffle_inputs_done`` is broadcast to
            all participating workers. A ``None`` result counts as success, and
            workers that failed with ``OSError`` are retried.

            Raises
            ------
            ValueError
                If *run_id* is not the current run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the shuffle
                has not been archived.
            RuntimeError
                If a worker returns a non-``OSError`` error, if a failing worker
                left the cluster after the shuffle was archived, or after three
                rounds in which no failing worker recovered.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers that failed with a retryable OSError; any
                # other error aborts the barrier. The resulting list is final, so
                # it is not recomputed afterwards.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Count rounds in which no failing worker recovered.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'cd806ec16ba2329f2cd86e74fd632cc1'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…22: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39571', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:34065', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:43371', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x4-chunks4] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'cd806ec16ba2329f2cd86e74fd632cc1'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…y:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:42723', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:40963', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41041', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x5-chunks5] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'cd806ec16ba2329f2cd86e74fd632cc1'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…n _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33355', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:38757', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41363', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x6-chunks6] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'd1cb5122ca49df832e6e59ceb0060c25'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…22: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45147', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:33203', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:43661', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x7-chunks7] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'd1cb5122ca49df832e6e59ceb0060c25'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…y:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38539', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:46847', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:34357', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x8-chunks8] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'd1cb5122ca49df832e6e59ceb0060c25'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…n _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:37997', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:35021', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42729', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x9-chunks9] (distributed.shuffle.tests.test_rechunk)

artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P a5d6a357fabfab24dace4a1ae3b6eb0d failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'a5d6a357fabfab24dace4a1ae3b6eb0d'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36439', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 2), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:37419', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:45359', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P a5d6a357fabfab24dace4a1ae3b6eb0d failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError