Temporarily disable gpuCI update CI job (#8945) #15320
43 fail, 112 skipped, 3,975 pass in 5h 36m 44s
Annotations
Check warning on line 0 in distributed.tests.test_client
github-actions / Unit Test Results
8 out of 9 runs failed: test_persist_async (distributed.tests.test_client)
artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 0s]
Raw output
IndexError: tuple index out of range
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40231', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:34811', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:36687', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_persist_async(c, s, a, b):
pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
x = da.ones((10, 10), chunks=(5, 10))
y = 2 * (x + 1)
assert len(y.dask) == 6
yy = c.persist(y)
assert len(y.dask) == 6
assert len(yy.dask) == 2
assert all(isinstance(v, Future) for v in yy.dask.values())
assert yy.__dask_keys__() == y.__dask_keys__()
g, h = c.compute([y, yy])
> gg, hh = await c.gather([g, h])
distributed/tests/test_client.py:2565:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:2427: in _gather
raise exception.with_traceback(traceback)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:1284: in finalize
return concatenate3(results)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5452: in concatenate3
chunks = chunks_from_arrays(arrays)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5237: in chunks_from_arrays
result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import contextlib
import math
import operator
import os
import pickle
import re
import sys
import traceback
import uuid
import warnings
from bisect import bisect
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial, reduce, wraps
from itertools import product, zip_longest
from numbers import Integral, Number
from operator import add, mul
from threading import Lock
from typing import Any, Literal, TypeVar, Union, cast
import numpy as np
from numpy.typing import ArrayLike
from packaging.version import Version
from tlz import accumulate, concat, first, groupby, partition
from tlz.curried import pluck
from toolz import frequencies
from dask import compute, config, core
from dask.array import chunk
from dask.array.chunk import getitem
from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
# Keep einsum_lookup and tensordot_lookup here for backwards compatibility
from dask.array.dispatch import ( # noqa: F401
concatenate_lookup,
einsum_lookup,
tensordot_lookup,
)
from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
from dask.array.utils import compute_meta, meta_from_array
from dask.base import (
DaskMethodsMixin,
compute_as_if_collection,
dont_optimize,
is_dask_collection,
named_schedulers,
persist,
tokenize,
)
from dask.blockwise import blockwise as core_blockwise
from dask.blockwise import broadcast_dimensions
from dask.context import globalmethod
from dask.core import quote
from dask.delayed import Delayed, delayed
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.layers import ArrayBlockIdDep, ArraySliceDep, ArrayValuesDep, reshapelist
from dask.sizeof import sizeof
from dask.typing import Graph, Key, NestedKeys
from dask.utils import (
IndexCallable,
SerializableLock,
cached_cumsum,
cached_property,
concrete,
derived_from,
format_bytes,
funcname,
has_keyword,
is_arraylike,
is_dataframe_like,
is_index_like,
is_integer,
is_series_like,
maybe_pluralize,
ndeepmap,
ndimlist,
parse_bytes,
typename,
)
from dask.widgets import get_template
T_IntOrNaN = Union[int, float] # Should be Union[int, Literal[np.nan]]
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
unknown_chunk_message = (
"\n\n"
"A possible solution: "
"https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n"
"Summary: to compute chunks sizes, use\n\n"
" x.compute_chunk_sizes() # for Dask Array `x`\n"
" ddf.to_dask_array(lengths=True) # for Dask DataFrame `ddf`"
)
class PerformanceWarning(Warning):
"""A warning given when bad chunking may cause poor performance"""
def getter(a, b, asarray=True, lock=None):
if isinstance(b, tuple) and any(x is None for x in b):
b2 = tuple(x for x in b if x is not None)
b3 = tuple(
None if x is None else slice(None, None)
for x in b
if not isinstance(x, Integral)
)
return getter(a, b2, asarray=asarray, lock=lock)[b3]
if lock:
lock.acquire()
try:
c = a[b]
# Below we special-case `np.matrix` to force a conversion to
# `np.ndarray` and preserve original Dask behavior for `getter`,
# as for all purposes `np.matrix` is array-like and thus
# `is_arraylike` evaluates to `True` in that case.
if asarray and (not is_arraylike(c) or isinstance(c, np.matrix)):
c = np.asarray(c)
finally:
if lock:
lock.release()
return c
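# Illustrative sketch (not part of the dask source): `getter` strips None
# (np.newaxis) entries before indexing, then reapplies them afterwards, so the
# backend only ever sees basic indices. For example:
#
#     >>> getter(np.arange(6).reshape(2, 3), (0, None, slice(None)))
#     array([[0, 1, 2]])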
def getter_nofancy(a, b, asarray=True, lock=None):
"""A simple wrapper around ``getter``.
Used to indicate to the optimization passes that the backend doesn't
support fancy indexing.
"""
return getter(a, b, asarray=asarray, lock=lock)
def getter_inline(a, b, asarray=True, lock=None):
"""A getter function that optimizations feel comfortable inlining
Slicing operations with this function may be inlined into a graph, such as
in the following rewrite
**Before**
>>> a = x[:10] # doctest: +SKIP
>>> b = a + 1 # doctest: +SKIP
>>> c = a * 2 # doctest: +SKIP
**After**
>>> b = x[:10] + 1 # doctest: +SKIP
>>> c = x[:10] * 2 # doctest: +SKIP
This inlining can be relevant to operations when running off of disk.
"""
return getter(a, b, asarray=asarray, lock=lock)
from dask.array.optimization import fuse_slice, optimize
# __array_function__ dict for mapping aliases and mismatching names
_HANDLED_FUNCTIONS = {}
def implements(*numpy_functions):
"""Register an __array_function__ implementation for dask.array.Array
Register that a function implements the API of a NumPy function (or several
NumPy functions in case of aliases) which is handled with
``__array_function__``.
Parameters
----------
\\*numpy_functions : callables
One or more NumPy functions that are handled by ``__array_function__``
and will be mapped by `implements` to a `dask.array` function.
"""
def decorator(dask_func):
for numpy_function in numpy_functions:
_HANDLED_FUNCTIONS[numpy_function] = dask_func
return dask_func
return decorator
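# Usage sketch for `implements` (hypothetical wiring; the real registrations
# live elsewhere in dask.array): map a NumPy function to a dask counterpart so
# that Array.__array_function__ can dispatch to it.
#
#     @implements(np.concatenate)
#     def _concatenate_via_dask(seq, axis=0):
#         return concatenate(seq, axis=axis)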
def _should_delegate(self, other) -> bool:
"""Check whether Dask should delegate to the other.
This implementation follows NEP-13:
https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
"""
if hasattr(other, "__array_ufunc__") and other.__array_ufunc__ is None:
return True
elif (
hasattr(other, "__array_ufunc__")
and not is_valid_array_chunk(other)
# don't delegate to our own parent classes
and not isinstance(self, type(other))
and type(self) is not type(other)
):
return True
return False
def check_if_handled_given_other(f):
"""Check if method is handled by Dask given type of other
Ensures proper deferral to upcast types in dunder operations without
assuming unknown types are automatically downcast types.
"""
@wraps(f)
def wrapper(self, other):
if _should_delegate(self, other):
return NotImplemented
else:
return f(self, other)
return wrapper
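# Usage sketch (hypothetical): guarding a binary dunder with this decorator
# returns NotImplemented when `other` should handle the operation, letting
# Python fall back to the reflected method per NEP-13.
#
#     @check_if_handled_given_other
#     def __add__(self, other):
#         return elemwise(operator.add, self, other)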
def slices_from_chunks(chunks):
"""Translate chunks tuple to a set of slices in product order
>>> slices_from_chunks(((2, 2), (3, 3, 3))) # doctest: +NORMALIZE_WHITESPACE
[(slice(0, 2, None), slice(0, 3, None)),
(slice(0, 2, None), slice(3, 6, None)),
(slice(0, 2, None), slice(6, 9, None)),
(slice(2, 4, None), slice(0, 3, None)),
(slice(2, 4, None), slice(3, 6, None)),
(slice(2, 4, None), slice(6, 9, None))]
"""
cumdims = [cached_cumsum(bds, initial_zero=True) for bds in chunks]
slices = [
[slice(s, s + dim) for s, dim in zip(starts, shapes)]
for starts, shapes in zip(cumdims, chunks)
]
return list(product(*slices))
def graph_from_arraylike(
arr, # Any array-like which supports slicing
chunks,
shape,
name,
getitem=getter,
lock=False,
asarray=True,
dtype=None,
inline_array=False,
) -> HighLevelGraph:
"""
HighLevelGraph for slicing chunks from an array-like according to a chunk pattern.
If ``inline_array`` is True, this makes a Blockwise layer of slicing tasks where the
array-like is embedded into every task.
If ``inline_array`` is False, this inserts the array-like as a standalone value in
a MaterializedLayer, then generates a Blockwise layer of slicing tasks that refer
to it.
>>> dict(graph_from_arraylike(arr, chunks=(2, 3), shape=(4, 6), name="X", inline_array=True)) # doctest: +SKIP
{(arr, 0, 0): (getter, arr, (slice(0, 2), slice(0, 3))),
(arr, 1, 0): (getter, arr, (slice(2, 4), slice(0, 3))),
(arr, 1, 1): (getter, arr, (slice(2, 4), slice(3, 6))),
(arr, 0, 1): (getter, arr, (slice(0, 2), slice(3, 6)))}
>>> dict( # doctest: +SKIP
graph_from_arraylike(arr, chunks=((2, 2), (3, 3)), shape=(4,6), name="X", inline_array=False)
)
{"original-X": arr,
('X', 0, 0): (getter, 'original-X', (slice(0, 2), slice(0, 3))),
('X', 1, 0): (getter, 'original-X', (slice(2, 4), slice(0, 3))),
('X', 1, 1): (getter, 'original-X', (slice(2, 4), slice(3, 6))),
('X', 0, 1): (getter, 'original-X', (slice(0, 2), slice(3, 6)))}
"""
chunks = normalize_chunks(chunks, shape, dtype=dtype)
out_ind = tuple(range(len(shape)))
if (
has_keyword(getitem, "asarray")
and has_keyword(getitem, "lock")
and (not asarray or lock)
):
kwargs = {"asarray": asarray, "lock": lock}
else:
# Common case, drop extra parameters
kwargs = {}
if inline_array:
layer = core_blockwise(
getitem,
name,
out_ind,
arr,
None,
ArraySliceDep(chunks),
out_ind,
numblocks={},
**kwargs,
)
return HighLevelGraph.from_collections(name, layer)
else:
original_name = "original-" + name
layers = {}
layers[original_name] = MaterializedLayer({original_name: arr})
layers[name] = core_blockwise(
getitem,
name,
out_ind,
original_name,
None,
ArraySliceDep(chunks),
out_ind,
numblocks={},
**kwargs,
)
deps = {
original_name: set(),
name: {original_name},
}
return HighLevelGraph(layers, deps)
def dotmany(A, B, leftfunc=None, rightfunc=None, **kwargs):
"""Dot product of many aligned chunks
>>> x = np.array([[1, 2], [1, 2]])
>>> y = np.array([[10, 20], [10, 20]])
>>> dotmany([x, x, x], [y, y, y])
array([[ 90, 180],
[ 90, 180]])
Optionally pass in functions to apply to the left and right chunks
>>> dotmany([x, x, x], [y, y, y], rightfunc=np.transpose)
array([[150, 150],
[150, 150]])
"""
if leftfunc:
A = map(leftfunc, A)
if rightfunc:
B = map(rightfunc, B)
return sum(map(partial(np.dot, **kwargs), A, B))
def _concatenate2(arrays, axes=None):
"""Recursively concatenate nested lists of arrays along axes
Each entry in axes corresponds to each level of the nested list. The
length of axes should correspond to the level of nesting of arrays.
If axes is an empty list or tuple, return arrays, or arrays[0] if
arrays is a list.
>>> x = np.array([[1, 2], [3, 4]])
>>> _concatenate2([x, x], axes=[0])
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]])
>>> _concatenate2([x, x], axes=[1])
array([[1, 2, 1, 2],
[3, 4, 3, 4]])
>>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
array([[1, 2, 1, 2],
[3, 4, 3, 4],
[1, 2, 1, 2],
[3, 4, 3, 4]])
Supports Iterators
>>> _concatenate2(iter([x, x]), axes=[1])
array([[1, 2, 1, 2],
[3, 4, 3, 4]])
Special Case
>>> _concatenate2([x, x], axes=())
array([[1, 2],
[3, 4]])
"""
if axes is None:
axes = []
if axes == ():
if isinstance(arrays, list):
return arrays[0]
else:
return arrays
if isinstance(arrays, Iterator):
arrays = list(arrays)
if not isinstance(arrays, (list, tuple)):
return arrays
if len(axes) > 1:
arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
concatenate = concatenate_lookup.dispatch(
type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0)))
)
if isinstance(arrays[0], dict):
# Handle concatenation of `dict`s, used as a replacement for structured
# arrays when that's not supported by the array library (e.g., CuPy).
keys = list(arrays[0].keys())
assert all(list(a.keys()) == keys for a in arrays)
ret = dict()
for k in keys:
ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0])
return ret
else:
return concatenate(arrays, axis=axes[0])
def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
"""
Tries to infer output dtype of ``func`` for a small set of input arguments.
Parameters
----------
func: Callable
Function for which output dtype is to be determined
args: List of array like
Arguments to the function, which would usually be used. Only attributes
``ndim`` and ``dtype`` are used.
kwargs: dict
Additional ``kwargs`` to the ``func``
funcname: String
Name of calling function to improve potential error messages
suggest_dtype: None/False or String
If not ``None`` adds suggestion to potential error message to specify a dtype
via the specified kwarg. Defaults to ``'dtype'``.
nout: None or Int
``None`` if function returns single output, integer if many.
Defaults to ``None``.
Returns
-------
: dtype or List of dtype
One or many dtypes (depending on ``nout``)
"""
from dask.array.utils import meta_from_array
# make sure that every arg is an evaluated array
args = [
(
np.ones_like(meta_from_array(x), shape=((1,) * x.ndim), dtype=x.dtype)
if is_arraylike(x)
else x
)
for x in args
]
try:
with np.errstate(all="ignore"):
o = func(*args, **kwargs)
except Exception as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
tb = "".join(traceback.format_tb(exc_traceback))
suggest = (
(
"Please specify the dtype explicitly using the "
"`{dtype}` kwarg.\n\n".format(dtype=suggest_dtype)
)
if suggest_dtype
else ""
)
msg = (
f"`dtype` inference failed in `{funcname}`.\n\n"
f"{suggest}"
"Original error is below:\n"
"------------------------\n"
f"{e!r}\n\n"
"Traceback:\n"
"---------\n"
f"{tb}"
)
else:
msg = None
if msg is not None:
raise ValueError(msg)
return getattr(o, "dtype", type(o)) if nout is None else tuple(e.dtype for e in o)
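# Illustrative call (sketch): infer the dtype np.add would produce for two
# int8 inputs; the one-element arrays below stand in for real chunks.
#
#     >>> apply_infer_dtype(np.add, [np.ones(1, dtype='i1'), np.ones(1, dtype='i1')], {}, "add")
#     dtype('int8')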
def normalize_arg(x):
"""Normalize user provided arguments to blockwise or map_blocks
We do a few things:
1. If they are string literals that might collide with blockwise_token then we
quote them
2. If they are large (as defined by sizeof) then we put them into the
graph on their own by using dask.delayed
"""
if is_dask_collection(x):
return x
elif isinstance(x, str) and re.match(r"_\d+", x):
return delayed(x)
elif isinstance(x, list) and len(x) >= 10:
return delayed(x)
elif sizeof(x) > 1e6:
return delayed(x)
else:
return x
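# Illustrative behavior (sketch): small plain values pass through unchanged,
# while strings matching the blockwise token pattern (e.g. "_0") and large or
# long arguments are wrapped in dask.delayed so they enter the graph on their
# own.
#
#     >>> normalize_arg(5)
#     5
#     >>> normalize_arg("_0")  # doctest: +ELLIPSIS
#     Delayed(...)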
def _pass_extra_kwargs(func, keys, *args, **kwargs):
"""Helper for :func:`dask.array.map_blocks` to pass `block_info` or `block_id`.
For each element of `keys`, a corresponding element of args is changed
to a keyword argument with that key, before all arguments are passed on
to `func`.
"""
kwargs.update(zip(keys, args))
return func(*args[len(keys) :], **kwargs)
def map_blocks(
func,
*args,
name=None,
token=None,
dtype=None,
chunks=None,
drop_axis=None,
new_axis=None,
enforce_ndim=False,
meta=None,
**kwargs,
):
"""Map a function across all blocks of a dask array.
Note that ``map_blocks`` will attempt to automatically determine the output
array type by calling ``func`` on 0-d versions of the inputs. Please refer to
the ``meta`` keyword argument below if you expect that the function will not
succeed when operating on 0-d arrays.
Parameters
----------
func : callable
Function to apply to every block in the array.
If ``func`` accepts ``block_info=`` or ``block_id=``
as keyword arguments, these will be passed dictionaries
containing information about input and output chunks/arrays
during computation. See examples for details.
args : dask arrays or other objects
dtype : np.dtype, optional
The ``dtype`` of the output array. It is recommended to provide this.
If not provided, will be inferred by applying the function to a small
set of fake data.
chunks : tuple, optional
Chunk shape of resulting blocks if the function does not preserve
shape. If not provided, the resulting array is assumed to have the same
block structure as the first input array.
drop_axis : number or iterable, optional
Dimensions lost by the function.
new_axis : number or iterable, optional
New dimensions created by the function. Note that these are applied
after ``drop_axis`` (if present). The size of each chunk along this
dimension will be set to 1. Please specify ``chunks`` if the individual
chunks have a different size.
enforce_ndim : bool, default False
Whether to enforce at runtime that the dimensionality of the array
produced by ``func`` actually matches that of the array returned by
``map_blocks``.
If True, this will raise an error when there is a mismatch.
token : string, optional
The key prefix to use for the output array. If not provided, will be
determined from the function name.
name : string, optional
The key name to use for the output array. Note that this fully
specifies the output key name, and must be unique. If not provided,
will be determined by a hash of the arguments.
meta : array-like, optional
The ``meta`` of the output array, when specified is expected to be an
array of the same type and dtype of that returned when calling ``.compute()``
on the array returned by this function. When not provided, ``meta`` will be
inferred by applying the function to a small set of fake data, usually a
0-d array. It's important to ensure that ``func`` can successfully complete
computation without raising exceptions when a 0-d array is passed to it;
otherwise, providing ``meta`` is required. If the output type is known beforehand
(e.g., ``np.ndarray``, ``cupy.ndarray``), an empty array of that type and
dtype can be passed, for example: ``meta=np.array((), dtype=np.int32)``.
**kwargs :
Other keyword arguments to pass to function. Values must be constants
(not dask.arrays)
See Also
--------
dask.array.map_overlap : Generalized operation with overlap between neighbors.
dask.array.blockwise : Generalized operation with control over block alignment.
Examples
--------
>>> import dask.array as da
>>> x = da.arange(6, chunks=3)
>>> x.map_blocks(lambda x: x * 2).compute()
array([ 0, 2, 4, 6, 8, 10])
The ``da.map_blocks`` function can also accept multiple arrays.
>>> d = da.arange(5, chunks=2)
>>> e = da.arange(5, chunks=2)
>>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
>>> f.compute()
array([ 0, 2, 6, 12, 20])
If the function changes shape of the blocks then you must provide chunks
explicitly.
>>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))
You have a bit of freedom in specifying chunks. If all of the output chunk
sizes are the same, you can provide just that chunk size as a single tuple.
>>> a = da.arange(18, chunks=(6,))
>>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))
If the function changes the dimension of the blocks you must specify the
created or destroyed dimensions.
>>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
... new_axis=[0, 2])
If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
add the necessary number of axes on the left.
Note that ``map_blocks()`` will concatenate chunks along axes specified by
the keyword parameter ``drop_axis`` prior to applying the function.
This is illustrated in the figure below:
.. image:: /images/map_blocks_drop_axis.png
Due to memory-size constraints, it is often not advisable to use ``drop_axis``
on an axis that is chunked. In that case, it is better not to use
``map_blocks`` but rather
``dask.array.reduction(..., axis=dropped_axes, concatenate=False)``, which
maintains a leaner memory footprint while dropping the axes.
Map_blocks aligns blocks by block positions without regard to shape. In the
following example we have two arrays with the same number of blocks but
with different shape and chunk sizes.
>>> x = da.arange(1000, chunks=(100,))
>>> y = da.arange(100, chunks=(10,))
The relevant attribute to match is numblocks.
>>> x.numblocks
(10,)
>>> y.numblocks
(10,)
If these match (up to broadcasting rules) then we can map arbitrary
functions across blocks
>>> def func(a, b):
... return np.array([a.max(), b.max()])
>>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
>>> _.compute()
array([ 99, 9, 199, 19, 299, 29, 399, 39, 499, 49, 599, 59, 699,
69, 799, 79, 899, 89, 999, 99])
Your block function can get information about where it is in the array by
accepting a special ``block_info`` or ``block_id`` keyword argument.
During computation, they will contain information about each of the input
and output chunks (and dask arrays) relevant to each call of ``func``.
>>> def func(block_info=None):
... pass
This will receive the following information:
>>> block_info # doctest: +SKIP
{0: {'shape': (1000,),
'num-chunks': (10,),
'chunk-location': (4,),
'array-location': [(400, 500)]},
None: {'shape': (1000,),
'num-chunks': (10,),
'chunk-location': (4,),
'array-location': [(400, 500)],
'chunk-shape': (100,),
'dtype': dtype('float64')}}
The keys to the ``block_info`` dictionary indicate which is the input and
output Dask array:
- **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
The dictionary key is ``0`` because that is the argument index corresponding
to the first input Dask array.
In cases where multiple Dask arrays have been passed as input to the function,
you can access them with the number corresponding to the input argument,
e.g. ``block_info[1]``, ``block_info[2]``, etc.
(Note that if you pass multiple Dask arrays as input to map_blocks,
the arrays must match each other by having matching numbers of chunks,
along corresponding dimensions up to broadcasting rules.)
- **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
and contains information about the output chunks.
The output chunk shape and dtype may be different from those of the input chunks.
For each dask array, ``block_info`` describes:
- ``shape``: the shape of the full Dask array,
- ``num-chunks``: the number of chunks of the full array in each dimension,
- ``chunk-location``: the chunk location (for example the fourth chunk over
in the first dimension), and
- ``array-location``: the array location within the full Dask array
(for example the slice corresponding to ``40:50``).
In addition to these, there are two extra parameters described by
``block_info`` for the output array (in ``block_info[None]``):
- ``chunk-shape``: the output chunk shape, and
- ``dtype``: the output dtype.
These features can be combined to synthesize an array from scratch, for
example:
>>> def func(block_info=None):
... loc = block_info[None]['array-location'][0]
... return np.arange(loc[0], loc[1])
>>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>
>>> _.compute()
array([0, 1, 2, 3, 4, 5, 6, 7])
``block_id`` is similar to ``block_info`` but contains only the ``chunk_location``:
>>> def func(block_id=None):
... pass
This will receive the following information:
>>> block_id # doctest: +SKIP
(4, 3)
You may specify the key name prefix of the resulting task in the graph with
the optional ``token`` keyword argument.
>>> x.map_blocks(lambda x: x + 1, name='increment')
dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>
For functions that may not handle 0-d arrays, it's also possible to specify
``meta`` with an empty array matching the type of the expected result. In
the example below, ``func`` will result in an ``IndexError`` when computing
``meta``:
>>> rng = da.random.default_rng()
>>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
a ``dtype``:
>>> import cupy # doctest: +SKIP
>>> rng = da.random.default_rng(cupy.random.default_rng()) # doctest: +SKIP
>>> dt = np.float32
>>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt)) # doctest: +SKIP
dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
"""
if drop_axis is None:
drop_axis = []
if not callable(func):
msg = (
"First argument must be callable function, not %s\n"
"Usage: da.map_blocks(function, x)\n"
" or: da.map_blocks(function, x, y, z)"
)
raise TypeError(msg % type(func).__name__)
if token:
warnings.warn(
"The `token=` keyword to `map_blocks` has been moved to `name=`. "
"Please use `name=` instead as the `token=` keyword will be removed "
"in a future release.",
category=FutureWarning,
)
name = token
name = f"{name or funcname(func)}-{tokenize(func, dtype, chunks, drop_axis, new_axis, *args, **kwargs)}"
new_axes = {}
if isinstance(drop_axis, Number):
drop_axis = [drop_axis]
if isinstanc…tack
"""
from dask.array import wrap
seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]
if not seq:
raise ValueError("Need array(s) to concatenate")
if axis is None:
seq = [a.flatten() for a in seq]
axis = 0
seq_metas = [meta_from_array(s) for s in seq]
_concatenate = concatenate_lookup.dispatch(
type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0)))
)
meta = _concatenate(seq_metas, axis=axis)
# Promote types to match meta
seq = [a.astype(meta.dtype) for a in seq]
# Find output array shape
ndim = len(seq[0].shape)
shape = tuple(
sum(a.shape[i] for a in seq) if i == axis else seq[0].shape[i]
for i in range(ndim)
)
# Drop empty arrays
seq2 = [a for a in seq if a.size]
if not seq2:
seq2 = seq
if axis < 0:
axis = ndim + axis
if axis >= ndim:
msg = (
"Axis must be less than than number of dimensions"
"\nData has %d dimensions, but got axis=%d"
)
raise ValueError(msg % (ndim, axis))
n = len(seq2)
if n == 0:
try:
return wrap.empty_like(meta, shape=shape, chunks=shape, dtype=meta.dtype)
except TypeError:
return wrap.empty(shape, chunks=shape, dtype=meta.dtype)
elif n == 1:
return seq2[0]
if not allow_unknown_chunksizes and not all(
i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2)
for i in range(ndim)
):
if any(map(np.isnan, seq2[0].shape)):
raise ValueError(
"Tried to concatenate arrays with unknown"
" shape %s.\n\nTwo solutions:\n"
" 1. Force concatenation pass"
" allow_unknown_chunksizes=True.\n"
" 2. Compute shapes with "
"[x.compute_chunk_sizes() for x in seq]" % str(seq2[0].shape)
)
raise ValueError("Shapes do not align: %s", [x.shape for x in seq2])
inds = [list(range(ndim)) for i in range(n)]
for i, ind in enumerate(inds):
ind[axis] = -(i + 1)
uc_args = list(concat(zip(seq2, inds)))
_, seq2 = unify_chunks(*uc_args, warn=False)
bds = [a.chunks for a in seq2]
chunks = (
seq2[0].chunks[:axis]
+ (sum((bd[axis] for bd in bds), ()),)
+ seq2[0].chunks[axis + 1 :]
)
cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in seq2]))
names = [a.name for a in seq2]
name = "concatenate-" + tokenize(names, axis)
keys = list(product([name], *[range(len(bd)) for bd in chunks]))
values = [
(names[bisect(cum_dims, key[axis + 1]) - 1],)
+ key[1 : axis + 1]
+ (key[axis + 1] - cum_dims[bisect(cum_dims, key[axis + 1]) - 1],)
+ key[axis + 2 :]
for key in keys
]
dsk = dict(zip(keys, values))
graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq2)
return Array(graph, name, chunks, meta=meta)
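# Illustrative call (sketch): concatenating two 2x2 arrays along axis 0
# simply appends the per-axis chunk tuples along that axis.
#
#     >>> x = da.ones((2, 2), chunks=(1, 2))
#     >>> concatenate([x, x], axis=0).chunks
#     ((1, 1, 1, 1), (2,))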
def load_store_chunk(
x: Any,
out: Any,
index: slice,
lock: Any,
return_stored: bool,
load_stored: bool,
):
"""
A function inserted in a Dask graph for storing a chunk.
Parameters
----------
x: array-like
An array (potentially a NumPy one)
out: array-like
Where to store results.
index: slice-like
Where to store result from ``x`` in ``out``.
lock: Lock-like or False
Lock to use before writing to ``out``.
return_stored: bool
Whether to return ``out``.
load_stored: bool
Whether to return the array stored in ``out``.
Ignored if ``return_stored`` is not ``True``.
Returns
-------
If return_stored=True and load_stored=False
out
If return_stored=True and load_stored=True
out[index]
If return_stored=False
None
Examples
--------
>>> a = np.ones((5, 6))
>>> b = np.empty(a.shape)
>>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
"""
if lock:
lock.acquire()
try:
if x is not None and x.size != 0:
if is_arraylike(x):
out[index] = x
else:
out[index] = np.asanyarray(x)
if return_stored and load_stored:
return out[index]
elif return_stored and not load_stored:
return out
else:
return None
finally:
if lock:
lock.release()
def store_chunk(
x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
):
return load_store_chunk(x, out, index, lock, return_stored, False)
A = TypeVar("A", bound=ArrayLike)
def load_chunk(out: A, index: slice, lock: Any) -> A:
return load_store_chunk(None, out, index, lock, True, True)
def insert_to_ooc(
keys: list,
chunks: tuple[tuple[int, ...], ...],
out: ArrayLike,
name: str,
*,
lock: Lock | bool = True,
region: tuple[slice, ...] | slice | None = None,
return_stored: bool = False,
load_stored: bool = False,
) -> dict:
"""
Creates a Dask graph for storing chunks from ``arr`` in ``out``.
Parameters
----------
keys: list
Dask keys of the input array
chunks: tuple
Dask chunks of the input array
out: array-like
Where to store results to
name: str
First element of dask keys
lock: Lock-like or bool, optional
Whether to lock or with what (default is ``True``,
which means a :class:`threading.Lock` instance).
region: slice-like, optional
Where in ``out`` to store ``arr``'s results
(default is ``None``, meaning all of ``out``).
return_stored: bool, optional
Whether to return ``out``
(default is ``False``, meaning ``None`` is returned).
load_stored: bool, optional
Whether to handle loading from ``out`` at the same time.
Ignored if ``return_stored`` is not ``True``.
(default is ``False``, meaning defer to ``return_stored``).
Returns
-------
dask graph of store operation
Examples
--------
>>> import dask.array as da
>>> d = da.ones((5, 6), chunks=(2, 3))
>>> a = np.empty(d.shape)
>>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123") # doctest: +SKIP
"""
if lock is True:
lock = Lock()
slices = slices_from_chunks(chunks)
if region:
slices = [fuse_slice(region, slc) for slc in slices]
if return_stored and load_stored:
func = load_store_chunk
args = (load_stored,)
else:
func = store_chunk # type: ignore
args = () # type: ignore
dsk = {
(name,) + t[1:]: (func, t, out, slc, lock, return_stored) + args
for t, slc in zip(core.flatten(keys), slices)
}
return dsk
def retrieve_from_ooc(
keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
) -> dict[tuple, Any]:
"""
Creates a Dask graph for loading stored ``keys`` from ``dsk``.
Parameters
----------
keys: Collection
A sequence containing Dask graph keys to load
dsk_pre: Mapping
A Dask graph corresponding to a Dask Array before computation
dsk_post: Mapping
A Dask graph corresponding to a Dask Array after computation
Examples
--------
>>> import dask.array as da
>>> d = da.ones((5, 6), chunks=(2, 3))
>>> a = np.empty(d.shape)
>>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
>>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()}) # doctest: +SKIP
"""
load_dsk = {
("load-" + k[0],) + k[1:]: (load_chunk, dsk_post[k]) + dsk_pre[k][3:-1] # type: ignore
for k in keys
}
return load_dsk
def _as_dtype(a, dtype):
if dtype is None:
return a
else:
return a.astype(dtype)
def asarray(
a, allow_unknown_chunksizes=False, dtype=None, order=None, *, like=None, **kwargs
):
"""Convert the input to a dask array.
Parameters
----------
a : array-like
Input data, in any form that can be converted to a dask array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples of
lists and ndarrays.
allow_unknown_chunksizes: bool
Allow unknown chunksizes, such as come from converting from dask
dataframes. Dask.array is unable to verify that chunks line up. If
data comes from differently aligned sources then this can cause
unexpected results.
dtype : data-type, optional
By default, the data-type is inferred from the input data.
order : {'C', 'F', 'A', 'K'}, optional
Memory layout. 'A' and 'K' depend on the order of input array a.
'C' row-major (C-style), 'F' column-major (Fortran-style) memory
representation. 'A' (any) means 'F' if a is Fortran contiguous, 'C'
otherwise. 'K' (keep) preserves input order. Defaults to 'C'.
like: array-like
Reference object to allow the creation of Dask arrays with chunks
that are not NumPy arrays. If an array-like passed in as ``like``
supports the ``__array_function__`` protocol, the chunk type of the
resulting array will be defined by it. In this case, it ensures the
creation of a Dask array compatible with that passed in via this
argument. If ``like`` is a Dask array, the chunk type of the
resulting array will be defined by the chunk type of ``like``.
Requires NumPy 1.20.0 or higher.
Returns
-------
out : dask array
Dask array interpretation of a.
Examples
--------
>>> import dask.array as da
>>> import numpy as np
>>> x = np.arange(3)
>>> da.asarray(x)
dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
>>> y = [[1, 2, 3], [4, 5, 6]]
>>> da.asarray(y)
dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>
.. warning::
`order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
or is a list or tuple of `Array`s.
"""
if like is None:
if isinstance(a, Array):
return _as_dtype(a, dtype)
elif hasattr(a, "to_dask_array"):
return _as_dtype(a.to_dask_array(), dtype)
elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
return _as_dtype(asarray(a.data, order=order), dtype)
elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
return _as_dtype(
stack(a, allow_unknown_chunksizes=allow_unknown_chunksizes), dtype
)
elif not isinstance(getattr(a, "shape", None), Iterable):
a = np.asarray(a, dtype=dtype, order=order)
else:
like_meta = meta_from_array(like)
if isinstance(a, Array):
return a.map_blocks(np.asarray, like=like_meta, dtype=dtype, order=order)
else:
a = np.asarray(a, like=like_meta, dtype=dtype, order=order)
return from_array(a, getitem=getter_inline, **kwargs)
def asanyarray(a, dtype=None, order=None, *, like=None, inline_array=False):
"""Convert the input to a dask array.
Subclasses of ``np.ndarray`` will be passed through as chunks unchanged.
Parameters
----------
a : array-like
Input data, in any form that can be converted to a dask array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples of
lists and ndarrays.
dtype : data-type, optional
By default, the data-type is inferred from the input data.
order : {'C', 'F', 'A', 'K'}, optional
Memory layout. 'A' and 'K' depend on the order of input array a.
'C' row-major (C-style), 'F' column-major (Fortran-style) memory
representation. 'A' (any) means 'F' if a is Fortran contiguous, 'C'
otherwise. 'K' (keep) preserves input order. Defaults to 'C'.
like: array-like
Reference object to allow the creation of Dask arrays with chunks
that are not NumPy arrays. If an array-like passed in as ``like``
supports the ``__array_function__`` protocol, the chunk type of the
resulting array will be defined by it. In this case, it ensures the
creation of a Dask array compatible with that passed in via this
argument. If ``like`` is a Dask array, the chunk type of the
resulting array will be defined by the chunk type of ``like``.
Requires NumPy 1.20.0 or higher.
inline_array:
Whether to inline the array in the resulting dask graph. For more information,
see the documentation for ``dask.array.from_array()``.
Returns
-------
out : dask array
Dask array interpretation of a.
Examples
--------
>>> import dask.array as da
>>> import numpy as np
>>> x = np.arange(3)
>>> da.asanyarray(x)
dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
>>> y = [[1, 2, 3], [4, 5, 6]]
>>> da.asanyarray(y)
dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>
.. warning::
`order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
or is a list or tuple of `Array`s.
"""
if like is None:
if isinstance(a, Array):
return _as_dtype(a, dtype)
elif hasattr(a, "to_dask_array"):
return _as_dtype(a.to_dask_array(), dtype)
elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
return _as_dtype(asarray(a.data, order=order), dtype)
elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
return _as_dtype(stack(a), dtype)
elif not isinstance(getattr(a, "shape", None), Iterable):
a = np.asanyarray(a, dtype=dtype, order=order)
else:
like_meta = meta_from_array(like)
if isinstance(a, Array):
return a.map_blocks(np.asanyarray, like=like_meta, dtype=dtype, order=order)
else:
a = np.asanyarray(a, like=like_meta, dtype=dtype, order=order)
return from_array(
a,
chunks=a.shape,
getitem=getter_inline,
asarray=False,
inline_array=inline_array,
)
def is_scalar_for_elemwise(arg):
"""
>>> is_scalar_for_elemwise(42)
True
>>> is_scalar_for_elemwise('foo')
True
>>> is_scalar_for_elemwise(True)
True
>>> is_scalar_for_elemwise(np.array(42))
True
>>> is_scalar_for_elemwise([1, 2, 3])
True
>>> is_scalar_for_elemwise(np.array([1, 2, 3]))
False
>>> is_scalar_for_elemwise(from_array(np.array(0), chunks=()))
False
>>> is_scalar_for_elemwise(np.dtype('i4'))
True
"""
# the second half of shape_condition is essentially just to ensure that
# dask series / frame are treated as scalars in elemwise.
maybe_shape = getattr(arg, "shape", None)
shape_condition = not isinstance(maybe_shape, Iterable) or any(
is_dask_collection(x) for x in maybe_shape
)
return (
np.isscalar(arg)
or shape_condition
or isinstance(arg, np.dtype)
or (isinstance(arg, np.ndarray) and arg.ndim == 0)
)
def broadcast_shapes(*shapes):
"""
Determines output shape from broadcasting arrays.
Parameters
----------
shapes : tuples
The shapes of the arguments.
Returns
-------
output_shape : tuple
Raises
------
ValueError
If the input shapes cannot be successfully broadcast together.
"""
if len(shapes) == 1:
return shapes[0]
out = []
for sizes in zip_longest(*map(reversed, shapes), fillvalue=-1):
if np.isnan(sizes).any():
dim = np.nan
else:
dim = 0 if 0 in sizes else np.max(sizes).item()
if any(i not in [-1, 0, 1, dim] and not np.isnan(i) for i in sizes):
raise ValueError(
"operands could not be broadcast together with "
"shapes {}".format(" ".join(map(str, shapes)))
)
out.append(dim)
return tuple(reversed(out))
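# Illustrative behavior (sketch), mirroring NumPy's broadcasting rules:
#
#     >>> broadcast_shapes((2, 3), (3,))
#     (2, 3)
#     >>> broadcast_shapes((2, 1), (1, 4))
#     (2, 4)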
def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
"""Apply an elementwise ufunc-like function blockwise across arguments.
Like numpy ufuncs, broadcasting rules are respected.
Parameters
----------
op : callable
The function to apply. Should be numpy ufunc-like in the parameters
that it accepts.
*args : Any
Arguments to pass to `op`. Non-dask array-like objects are first
converted to dask arrays, then all arrays are broadcast together before
applying the function blockwise across all arguments. Any scalar
arguments are passed as-is following normal numpy ufunc behavior.
out : dask array, optional
If out is a dask.array then this overwrites the contents of that array
with the result.
where : array_like, optional
An optional boolean mask marking locations where the ufunc should be
applied. Can be a scalar, dask array, or any other array-like object.
Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add``
for more information.
dtype : dtype, optional
If provided, overrides the output array dtype.
name : str, optional
A unique key name to use when building the backing dask graph. If not
provided, one will be automatically generated based on the input
arguments.
Examples
--------
>>> elemwise(add, x, y) # doctest: +SKIP
>>> elemwise(sin, x) # doctest: +SKIP
>>> elemwise(sin, x, out=dask_array) # doctest: +SKIP
See Also
--------
blockwise
"""
if kwargs:
raise TypeError(
f"{op.__name__} does not take the following keyword arguments "
f"{sorted(kwargs)}"
)
out = _elemwise_normalize_out(out)
where = _elemwise_normalize_where(where)
args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args]
shapes = []
for arg in args:
shape = getattr(arg, "shape", ())
if any(is_dask_collection(x) for x in shape):
# Want to exclude Delayed shapes and dd.Scalar
shape = ()
shapes.append(shape)
if isinstance(where, Array):
shapes.append(where.shape)
if isinstance(out, Array):
shapes.append(out.shape)
shapes = [s if isinstance(s, Iterable) else () for s in shapes]
out_ndim = len(
broadcast_shapes(*shapes)
) # Raises ValueError if dimensions mismatch
expr_inds = tuple(range(out_ndim))[::-1]
if dtype is not None:
need_enforce_dtype = True
else:
# We follow NumPy's rules for dtype promotion, which special cases
# scalars and 0d ndarrays (which it considers equivalent) by using
# their values to compute the result dtype:
# https://github.com/numpy/numpy/issues/6240
# We don't inspect the values of 0d dask arrays, because these could
# hold potentially very expensive calculations. Instead, we treat
# them just like other arrays, and if necessary cast the result of op
# to match.
vals = [
(
np.empty((1,) * max(1, a.ndim), dtype=a.dtype)
if not is_scalar_for_elemwise(a)
else a
)
for a in args
]
try:
dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False)
except Exception:
return NotImplemented
need_enforce_dtype = any(
not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args
)
if not name:
name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}"
blockwise_kwargs = dict(dtype=dtype, name=name, token=funcname(op).strip("_"))
if where is not True:
blockwise_kwargs["elemwise_where_function"] = op
op = _elemwise_handle_where
args.extend([where, out])
if need_enforce_dtype:
blockwise_kwargs["enforce_dtype"] = dtype
blockwise_kwargs["enforce_dtype_function"] = op
op = _enforce_dtype
result = blockwise(
op,
expr_inds,
*concat(
(a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None)
for a in args
),
**blockwise_kwargs,
)
return handle_out(out, result)
def _elemwise_normalize_where(where):
if where is True:
return True
elif where is False or where is None:
return False
return asarray(where)
def _elemwise_handle_where(*args, **kwargs):
function = kwargs.pop("elemwise_where_function")
*args, where, out = args
if hasattr(out, "copy"):
out = out.copy()
return function(*args, where=where, out=out, **kwargs)
def _elemwise_normalize_out(out):
if isinstance(out, tuple):
if len(out) == 1:
out = out[0]
elif len(out) > 1:
raise NotImplementedError("The out parameter is not fully supported")
else:
out = None
if not (out is None or isinstance(out, Array)):
raise NotImplementedError(
f"The out parameter is not fully supported."
f" Received type {type(out).__name__}, expected Dask Array"
)
return out
def handle_out(out, result):
"""Handle out parameters
If out is a dask.array then this overwrites the contents of that array with
the result
"""
out = _elemwise_normalize_out(out)
if isinstance(out, Array):
if out.shape != result.shape:
raise ValueError(
"Mismatched shapes between result and out parameter. "
"out=%s, result=%s" % (str(out.shape), str(result.shape))
)
out._chunks = result.chunks
out.dask = result.dask
out._meta = result._meta
out._name = result.name
return out
else:
return result
def _enforce_dtype(*args, **kwargs):
"""Calls a function and converts its result to the given dtype.
The parameters have deliberately been given unwieldy names to avoid
clashes with keyword arguments consumed by blockwise
A dtype of `object` is treated as a special case and not enforced,
because it is used as a dummy value in some places when the result will
not be a block in an Array.
Parameters
----------
enforce_dtype : dtype
Result dtype
enforce_dtype_function : callable
The wrapped function, which will be passed the remaining arguments
"""
dtype = kwargs.pop("enforce_dtype")
function = kwargs.pop("enforce_dtype_function")
result = function(*args, **kwargs)
if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
if not np.can_cast(result, dtype, casting="same_kind"):
raise ValueError(
"Inferred dtype from function %r was %r "
"but got %r, which can't be cast using "
"casting='same_kind'"
% (funcname(function), str(dtype), str(result.dtype))
)
if np.isscalar(result):
# scalar astype method doesn't take the keyword arguments, so
# have to convert via 0-dimensional array and back.
result = result.astype(dtype)
else:
try:
result = result.astype(dtype, copy=False)
except TypeError:
# Missing copy kwarg
result = result.astype(dtype)
return result
def broadcast_to(x, shape, chunks=None, meta=None):
"""Broadcast an array to a new shape.
Parameters
----------
x : array_like
The array to broadcast.
shape : tuple
The shape of the desired array.
chunks : tuple, optional
If provided, then the result will use these chunks instead of the same
chunks as the source array. Setting chunks explicitly as part of
broadcast_to is more efficient than rechunking afterwards. Chunks are
only allowed to differ from the original shape along dimensions that
are new in the result or have size 1 in the input array.
meta : empty ndarray
empty ndarray created with same NumPy backend, ndim and dtype as the
Dask Array being created (overrides dtype)
Returns
-------
broadcast : dask array
See Also
--------
:func:`numpy.broadcast_to`
"""
x = asarray(x)
shape = tuple(shape)
if meta is None:
meta = meta_from_array(x)
if x.shape == shape and (chunks is None or chunks == x.chunks):
return x
ndim_new = len(shape) - x.ndim
if ndim_new < 0 or any(
new != old for new, old in zip(shape[ndim_new:], x.shape) if old != 1
):
raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")
if chunks is None:
chunks = tuple((s,) for s in shape[:ndim_new]) + tuple(
bd if old > 1 else (new,)
for bd, old, new in zip(x.chunks, x.shape, shape[ndim_new:])
)
else:
chunks = normalize_chunks(
chunks, shape, dtype=x.dtype, previous_chunks=x.chunks
)
for old_bd, new_bd in zip(x.chunks, chunks[ndim_new:]):
if old_bd != new_bd and old_bd != (1,):
raise ValueError(
"cannot broadcast chunks %s to chunks %s: "
"new chunks must either be along a new "
"dimension or a dimension of size 1" % (x.chunks, chunks)
)
name = "broadcast_to-" + tokenize(x, shape, chunks)
dsk = {}
enumerated_chunks = product(*(enumerate(bds) for bds in chunks))
for new_index, chunk_shape in (zip(*ec) for ec in enumerated_chunks):
old_index = tuple(
0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:])
)
old_key = (x.name,) + old_index
new_key = (name,) + new_index
dsk[new_key] = (np.broadcast_to, old_key, quote(chunk_shape))
graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
return Array(graph, name, chunks, dtype=x.dtype, meta=meta)
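# Illustrative call (sketch): new leading dimensions become single chunks,
# and existing chunks are preserved along unchanged axes.
#
#     >>> x = da.ones((3,), chunks=(3,))
#     >>> broadcast_to(x, (2, 3)).chunks
#     ((2,), (3,))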
@derived_from(np)
def broadcast_arrays(*args, subok=False):
subok = bool(subok)
to_array = asanyarray if subok else asarray
args = tuple(to_array(e) for e in args)
# Unify uneven chunking
inds = [list(reversed(range(x.ndim))) for x in args]
uc_args = concat(zip(args, inds))
_, args = unify_chunks(*uc_args, warn=False)
shape = broadcast_shapes(*(e.shape for e in args))
chunks = broadcast_chunks(*(e.chunks for e in args))
if NUMPY_GE_200:
result = tuple(broadcast_to(e, shape=shape, chunks=chunks) for e in args)
else:
result = [broadcast_to(e, shape=shape, chunks=chunks) for e in args]
return result
def offset_func(func, offset, *args):
"""Offsets inputs by offset
>>> double = lambda x: x * 2
>>> f = offset_func(double, (10,))
>>> f(1)
22
>>> f(300)
620
"""
def _offset(*args):
args2 = list(map(add, args, offset))
return func(*args2)
with contextlib.suppress(Exception):
_offset.__name__ = "offset_" + func.__name__
return _offset
def chunks_from_arrays(arrays):
"""Chunks tuple from nested list of arrays
>>> x = np.array([1, 2])
>>> chunks_from_arrays([x, x])
((2, 2),)
>>> x = np.array([[1, 2]])
>>> chunks_from_arrays([[x], [x]])
((1, 1), (2,))
>>> x = np.array([[1, 2]])
>>> chunks_from_arrays([[x, x]])
((1,), (2, 2))
>>> chunks_from_arrays([1, 1])
((1, 1),)
"""
if not arrays:
return ()
result = []
dim = 0
def shape(x):
try:
return x.shape if x.shape else (1,)
except AttributeError:
return (1,)
while isinstance(arrays, (list, tuple)):
> result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
E IndexError: tuple index out of range
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5237: IndexError
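For reference, the failing line indexes a chunk's shape tuple by nesting depth (dim). A minimal standalone sketch (hypothetical repro, not taken from this report) of how that raises IndexError: if a leaf of the nested list is 0-d rather than a 2-d block, shape falls back to (1,), which cannot be indexed at dim=1.

import numpy as np

def shape(x):
    # same fallback as in chunks_from_arrays: 0-d inputs report shape (1,)
    try:
        return x.shape if x.shape else (1,)
    except AttributeError:
        return (1,)

# Two levels of nesting mean dims 0 and 1 are probed, but each leaf is 0-d:
arrays = [[np.float64(1.0)], [np.float64(2.0)]]
print(shape(arrays[0][0])[0])  # dim=0: fine, prints 1
print(shape(arrays[0][0])[1])  # dim=1: IndexError: tuple index out of range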
Check warning on line 0 in distributed.tests.test_client
github-actions / Unit Test Results
8 out of 9 runs failed: test_release_persisted_collection (distributed.tests.test_client)
artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 0s]
Raw output
IndexError: tuple index out of range
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38475', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:42523', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:39277', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_release_persisted_collection(c, s, a, b):
np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
arr = c.persist(da.random.random((10,), chunks=(10,)))
await wait(arr)
_release_persisted(arr)
while s.tasks:
await asyncio.sleep(0.01)
with pytest.raises(CancelledError):
> await c.compute(arr)
distributed/tests/test_client.py:8235:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:1284: in finalize
return concatenate3(results)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5452: in concatenate3
chunks = chunks_from_arrays(arrays)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5237: in chunks_from_arrays
result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import contextlib
import math
import operator
import os
import pickle
import re
import sys
import traceback
import uuid
import warnings
from bisect import bisect
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial, reduce, wraps
from itertools import product, zip_longest
from numbers import Integral, Number
from operator import add, mul
from threading import Lock
from typing import Any, Literal, TypeVar, Union, cast
import numpy as np
from numpy.typing import ArrayLike
from packaging.version import Version
from tlz import accumulate, concat, first, groupby, partition
from tlz.curried import pluck
from toolz import frequencies
from dask import compute, config, core
from dask.array import chunk
from dask.array.chunk import getitem
from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
# Keep einsum_lookup and tensordot_lookup here for backwards compatibility
from dask.array.dispatch import ( # noqa: F401
concatenate_lookup,
einsum_lookup,
tensordot_lookup,
)
from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
from dask.array.utils import compute_meta, meta_from_array
from dask.base import (
DaskMethodsMixin,
compute_as_if_collection,
dont_optimize,
is_dask_collection,
named_schedulers,
persist,
tokenize,
)
from dask.blockwise import blockwise as core_blockwise
from dask.blockwise import broadcast_dimensions
from dask.context import globalmethod
from dask.core import quote
from dask.delayed import Delayed, delayed
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.layers import ArrayBlockIdDep, ArraySliceDep, ArrayValuesDep, reshapelist
from dask.sizeof import sizeof
from dask.typing import Graph, Key, NestedKeys
from dask.utils import (
IndexCallable,
SerializableLock,
cached_cumsum,
cached_property,
concrete,
derived_from,
format_bytes,
funcname,
has_keyword,
is_arraylike,
is_dataframe_like,
is_index_like,
is_integer,
is_series_like,
maybe_pluralize,
ndeepmap,
ndimlist,
parse_bytes,
typename,
)
from dask.widgets import get_template
T_IntOrNaN = Union[int, float] # Should be Union[int, Literal[np.nan]]
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
unknown_chunk_message = (
"\n\n"
"A possible solution: "
"https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n"
"Summary: to compute chunks sizes, use\n\n"
" x.compute_chunk_sizes() # for Dask Array `x`\n"
" ddf.to_dask_array(lengths=True) # for Dask DataFrame `ddf`"
)
class PerformanceWarning(Warning):
"""A warning given when bad chunking may cause poor performance"""
def getter(a, b, asarray=True, lock=None):
if isinstance(b, tuple) and any(x is None for x in b):
b2 = tuple(x for x in b if x is not None)
b3 = tuple(
None if x is None else slice(None, None)
for x in b
if not isinstance(x, Integral)
)
return getter(a, b2, asarray=asarray, lock=lock)[b3]
if lock:
lock.acquire()
try:
c = a[b]
# Below we special-case `np.matrix` to force a conversion to
# `np.ndarray` and preserve original Dask behavior for `getter`,
# as for all purposes `np.matrix` is array-like and thus
# `is_arraylike` evaluates to `True` in that case.
if asarray and (not is_arraylike(c) or isinstance(c, np.matrix)):
c = np.asarray(c)
finally:
if lock:
lock.release()
return c
def getter_nofancy(a, b, asarray=True, lock=None):
"""A simple wrapper around ``getter``.
Used to indicate to the optimization passes that the backend doesn't
support fancy indexing.
"""
return getter(a, b, asarray=asarray, lock=lock)
def getter_inline(a, b, asarray=True, lock=None):
"""A getter function that optimizations feel comfortable inlining
Slicing operations with this function may be inlined into a graph, such as
in the following rewrite
**Before**
>>> a = x[:10] # doctest: +SKIP
>>> b = a + 1 # doctest: +SKIP
>>> c = a * 2 # doctest: +SKIP
**After**
>>> b = x[:10] + 1 # doctest: +SKIP
>>> c = x[:10] * 2 # doctest: +SKIP
This inlining can be relevant to operations when running off of disk.
"""
return getter(a, b, asarray=asarray, lock=lock)
from dask.array.optimization import fuse_slice, optimize
# __array_function__ dict for mapping aliases and mismatching names
_HANDLED_FUNCTIONS = {}
def implements(*numpy_functions):
"""Register an __array_function__ implementation for dask.array.Array
Register that a function implements the API of a NumPy function (or several
NumPy functions in case of aliases) which is handled with
``__array_function__``.
Parameters
----------
\\*numpy_functions : callables
One or more NumPy functions that are handled by ``__array_function__``
and will be mapped by `implements` to a `dask.array` function.
"""
def decorator(dask_func):
for numpy_function in numpy_functions:
_HANDLED_FUNCTIONS[numpy_function] = dask_func
return dask_func
return decorator
def _should_delegate(self, other) -> bool:
"""Check whether Dask should delegate to the other.
This implementation follows NEP-13:
https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
"""
if hasattr(other, "__array_ufunc__") and other.__array_ufunc__ is None:
return True
elif (
hasattr(other, "__array_ufunc__")
and not is_valid_array_chunk(other)
# don't delegate to our own parent classes
and not isinstance(self, type(other))
and type(self) is not type(other)
):
return True
return False
def check_if_handled_given_other(f):
"""Check if method is handled by Dask given type of other
Ensures proper deferral to upcast types in dunder operations without
assuming unknown types are automatically downcast types.
"""
@wraps(f)
def wrapper(self, other):
if _should_delegate(self, other):
return NotImplemented
else:
return f(self, other)
return wrapper
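# Editorial sketch of the NEP-13 deferral that `check_if_handled_given_other`
# provides: a dunder op returns NotImplemented when the other operand opts out
# of `__array_ufunc__`. Both classes here are hypothetical.
def _example_delegation():
    class _Opaque:
        __array_ufunc__ = None  # signals "I handle my own binary ops"

    class _Wrapped:
        @check_if_handled_given_other
        def __add__(self, other):
            return "handled by _Wrapped"

    assert _Wrapped().__add__(_Opaque()) is NotImplemented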
def slices_from_chunks(chunks):
"""Translate chunks tuple to a set of slices in product order
>>> slices_from_chunks(((2, 2), (3, 3, 3))) # doctest: +NORMALIZE_WHITESPACE
[(slice(0, 2, None), slice(0, 3, None)),
(slice(0, 2, None), slice(3, 6, None)),
(slice(0, 2, None), slice(6, 9, None)),
(slice(2, 4, None), slice(0, 3, None)),
(slice(2, 4, None), slice(3, 6, None)),
(slice(2, 4, None), slice(6, 9, None))]
"""
cumdims = [cached_cumsum(bds, initial_zero=True) for bds in chunks]
slices = [
[slice(s, s + dim) for s, dim in zip(starts, shapes)]
for starts, shapes in zip(cumdims, chunks)
]
return list(product(*slices))
def graph_from_arraylike(
arr, # Any array-like which supports slicing
chunks,
shape,
name,
getitem=getter,
lock=False,
asarray=True,
dtype=None,
inline_array=False,
) -> HighLevelGraph:
"""
HighLevelGraph for slicing chunks from an array-like according to a chunk pattern.
If ``inline_array`` is True, this makes a Blockwise layer of slicing tasks where the
array-like is embedded into every task.
If ``inline_array`` is False, this inserts the array-like as a standalone value in
a MaterializedLayer, then generates a Blockwise layer of slicing tasks that refer
to it.
>>> dict(graph_from_arraylike(arr, chunks=(2, 3), shape=(4, 6), name="X", inline_array=True)) # doctest: +SKIP
{(arr, 0, 0): (getter, arr, (slice(0, 2), slice(0, 3))),
(arr, 1, 0): (getter, arr, (slice(2, 4), slice(0, 3))),
(arr, 1, 1): (getter, arr, (slice(2, 4), slice(3, 6))),
(arr, 0, 1): (getter, arr, (slice(0, 2), slice(3, 6)))}
>>> dict( # doctest: +SKIP
graph_from_arraylike(arr, chunks=((2, 2), (3, 3)), shape=(4,6), name="X", inline_array=False)
)
{"original-X": arr,
('X', 0, 0): (getter, 'original-X', (slice(0, 2), slice(0, 3))),
('X', 1, 0): (getter, 'original-X', (slice(2, 4), slice(0, 3))),
('X', 1, 1): (getter, 'original-X', (slice(2, 4), slice(3, 6))),
('X', 0, 1): (getter, 'original-X', (slice(0, 2), slice(3, 6)))}
"""
chunks = normalize_chunks(chunks, shape, dtype=dtype)
out_ind = tuple(range(len(shape)))
if (
has_keyword(getitem, "asarray")
and has_keyword(getitem, "lock")
and (not asarray or lock)
):
kwargs = {"asarray": asarray, "lock": lock}
else:
# Common case, drop extra parameters
kwargs = {}
if inline_array:
layer = core_blockwise(
getitem,
name,
out_ind,
arr,
None,
ArraySliceDep(chunks),
out_ind,
numblocks={},
**kwargs,
)
return HighLevelGraph.from_collections(name, layer)
else:
original_name = "original-" + name
layers = {}
layers[original_name] = MaterializedLayer({original_name: arr})
layers[name] = core_blockwise(
getitem,
name,
out_ind,
original_name,
None,
ArraySliceDep(chunks),
out_ind,
numblocks={},
**kwargs,
)
deps = {
original_name: set(),
name: {original_name},
}
return HighLevelGraph(layers, deps)
def dotmany(A, B, leftfunc=None, rightfunc=None, **kwargs):
"""Dot product of many aligned chunks
>>> x = np.array([[1, 2], [1, 2]])
>>> y = np.array([[10, 20], [10, 20]])
>>> dotmany([x, x, x], [y, y, y])
array([[ 90, 180],
[ 90, 180]])
Optionally pass in functions to apply to the left and right chunks
>>> dotmany([x, x, x], [y, y, y], rightfunc=np.transpose)
array([[150, 150],
[150, 150]])
"""
if leftfunc:
A = map(leftfunc, A)
if rightfunc:
B = map(rightfunc, B)
return sum(map(partial(np.dot, **kwargs), A, B))
def _concatenate2(arrays, axes=None):
"""Recursively concatenate nested lists of arrays along axes
Each entry in axes corresponds to each level of the nested list. The
length of axes should correspond to the level of nesting of arrays.
If axes is an empty list or tuple, return arrays, or arrays[0] if
arrays is a list.
>>> x = np.array([[1, 2], [3, 4]])
>>> _concatenate2([x, x], axes=[0])
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]])
>>> _concatenate2([x, x], axes=[1])
array([[1, 2, 1, 2],
[3, 4, 3, 4]])
>>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
array([[1, 2, 1, 2],
[3, 4, 3, 4],
[1, 2, 1, 2],
[3, 4, 3, 4]])
Supports Iterators
>>> _concatenate2(iter([x, x]), axes=[1])
array([[1, 2, 1, 2],
[3, 4, 3, 4]])
Special Case
>>> _concatenate2([x, x], axes=())
array([[1, 2],
[3, 4]])
"""
if axes is None:
axes = []
if axes == ():
if isinstance(arrays, list):
return arrays[0]
else:
return arrays
if isinstance(arrays, Iterator):
arrays = list(arrays)
if not isinstance(arrays, (list, tuple)):
return arrays
if len(axes) > 1:
arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
concatenate = concatenate_lookup.dispatch(
type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0)))
)
if isinstance(arrays[0], dict):
# Handle concatenation of `dict`s, used as a replacement for structured
# arrays when that's not supported by the array library (e.g., CuPy).
keys = list(arrays[0].keys())
assert all(list(a.keys()) == keys for a in arrays)
ret = dict()
for k in keys:
ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0])
return ret
else:
return concatenate(arrays, axis=axes[0])
def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
"""
Tries to infer output dtype of ``func`` for a small set of input arguments.
Parameters
----------
func: Callable
Function for which output dtype is to be determined
args: List of array-like
Arguments with which the function would usually be called. Only the
``ndim`` and ``dtype`` attributes are used.
kwargs: dict
Additional ``kwargs`` to the ``func``
funcname: String
Name of calling function to improve potential error messages
suggest_dtype: None/False or String
If not ``None``, adds a suggestion to the potential error message to specify
a dtype via the specified kwarg. Defaults to ``'dtype'``.
nout: None or Int
``None`` if the function returns a single output, an integer if it returns many.
Defaults to ``None``.
Returns
-------
: dtype or List of dtype
One or many dtypes (depending on ``nout``)
"""
from dask.array.utils import meta_from_array
# make sure that every arg is an evaluated array
args = [
(
np.ones_like(meta_from_array(x), shape=((1,) * x.ndim), dtype=x.dtype)
if is_arraylike(x)
else x
)
for x in args
]
try:
with np.errstate(all="ignore"):
o = func(*args, **kwargs)
except Exception as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
tb = "".join(traceback.format_tb(exc_traceback))
suggest = (
(
"Please specify the dtype explicitly using the "
"`{dtype}` kwarg.\n\n".format(dtype=suggest_dtype)
)
if suggest_dtype
else ""
)
msg = (
f"`dtype` inference failed in `{funcname}`.\n\n"
f"{suggest}"
"Original error is below:\n"
"------------------------\n"
f"{e!r}\n\n"
"Traceback:\n"
"---------\n"
f"{tb}"
)
else:
msg = None
if msg is not None:
raise ValueError(msg)
return getattr(o, "dtype", type(o)) if nout is None else tuple(e.dtype for e in o)
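# Editorial usage sketch: dtype inference on tiny stand-in arrays, as used by
# `elemwise` and `map_blocks` when no explicit dtype is given.
def _example_apply_infer_dtype():
    import numpy as np

    x = np.ones((3, 3), dtype="i4")
    y = np.ones((3, 3), dtype="f8")
    # int32 combined with float64 promotes to float64
    assert apply_infer_dtype(np.add, [x, y], {}, "add") == np.dtype("f8")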
def normalize_arg(x):
"""Normalize user provided arguments to blockwise or map_blocks
We do a few things:
1. If they are string literals that might collide with blockwise_token then we
quote them
2. If they are large (as defined by sizeof) then we put them into the
graph on their own by using dask.delayed
"""
if is_dask_collection(x):
return x
elif isinstance(x, str) and re.match(r"_\d+", x):
return delayed(x)
elif isinstance(x, list) and len(x) >= 10:
return delayed(x)
elif sizeof(x) > 1e6:
return delayed(x)
else:
return x
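# Editorial sketch: small constants pass through `normalize_arg` unchanged,
# while large objects are wrapped in `dask.delayed` so they are embedded in
# the graph once instead of being copied into every task.
def _example_normalize_arg():
    small = [1, 2, 3]
    big = list(range(1000))
    assert normalize_arg(small) is small  # short list, kept inline
    assert type(normalize_arg(big)).__name__ == "Delayed"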
def _pass_extra_kwargs(func, keys, *args, **kwargs):
"""Helper for :func:`dask.array.map_blocks` to pass `block_info` or `block_id`.
For each element of `keys`, a corresponding element of args is changed
to a keyword argument with that key, before all arguments are passed on
to `func`.
"""
kwargs.update(zip(keys, args))
return func(*args[len(keys) :], **kwargs)
def map_blocks(
func,
*args,
name=None,
token=None,
dtype=None,
chunks=None,
drop_axis=None,
new_axis=None,
enforce_ndim=False,
meta=None,
**kwargs,
):
"""Map a function across all blocks of a dask array.
Note that ``map_blocks`` will attempt to automatically determine the output
array type by calling ``func`` on 0-d versions of the inputs. Please refer to
the ``meta`` keyword argument below if you expect that the function will not
succeed when operating on 0-d arrays.
Parameters
----------
func : callable
Function to apply to every block in the array.
If ``func`` accepts ``block_info=`` or ``block_id=``
as keyword arguments, these will be passed dictionaries
containing information about input and output chunks/arrays
during computation. See examples for details.
args : dask arrays or other objects
dtype : np.dtype, optional
The ``dtype`` of the output array. It is recommended to provide this.
If not provided, will be inferred by applying the function to a small
set of fake data.
chunks : tuple, optional
Chunk shape of resulting blocks if the function does not preserve
shape. If not provided, the resulting array is assumed to have the same
block structure as the first input array.
drop_axis : number or iterable, optional
Dimensions lost by the function.
new_axis : number or iterable, optional
New dimensions created by the function. Note that these are applied
after ``drop_axis`` (if present). The size of each chunk along this
dimension will be set to 1. Please specify ``chunks`` if the individual
chunks have a different size.
enforce_ndim : bool, default False
Whether to enforce at runtime that the dimensionality of the array
produced by ``func`` actually matches that of the array returned by
``map_blocks``.
If True, this will raise an error when there is a mismatch.
token : string, optional
The key prefix to use for the output array. If not provided, will be
determined from the function name.
name : string, optional
The key name to use for the output array. Note that this fully
specifies the output key name, and must be unique. If not provided,
will be determined by a hash of the arguments.
meta : array-like, optional
The ``meta`` of the output array, when specified is expected to be an
array of the same type and dtype of that returned when calling ``.compute()``
on the array returned by this function. When not provided, ``meta`` will be
inferred by applying the function to a small set of fake data, usually a
0-d array. If ``func`` cannot successfully complete computation without
raising exceptions when given a 0-d array, ``meta`` must be provided. If the
output type is known beforehand (e.g., ``np.ndarray``, ``cupy.ndarray``), an
empty array of that type and dtype can be passed, for example:
``meta=np.array((), dtype=np.int32)``.
**kwargs :
Other keyword arguments to pass to function. Values must be constants
(not dask.arrays)
See Also
--------
dask.array.map_overlap : Generalized operation with overlap between neighbors.
dask.array.blockwise : Generalized operation with control over block alignment.
Examples
--------
>>> import dask.array as da
>>> x = da.arange(6, chunks=3)
>>> x.map_blocks(lambda x: x * 2).compute()
array([ 0, 2, 4, 6, 8, 10])
The ``da.map_blocks`` function can also accept multiple arrays.
>>> d = da.arange(5, chunks=2)
>>> e = da.arange(5, chunks=2)
>>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
>>> f.compute()
array([ 0, 2, 6, 12, 20])
If the function changes shape of the blocks then you must provide chunks
explicitly.
>>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))
You have a bit of freedom in specifying chunks. If all of the output chunk
sizes are the same, you can provide just that chunk size as a single tuple.
>>> a = da.arange(18, chunks=(6,))
>>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))
If the function changes the dimension of the blocks you must specify the
created or destroyed dimensions.
>>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
... new_axis=[0, 2])
If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
add the necessary number of axes on the left.
Note that ``map_blocks()`` will concatenate chunks along axes specified by
the keyword parameter ``drop_axis`` prior to applying the function.
This is illustrated in the figure below:
.. image:: /images/map_blocks_drop_axis.png
Due to memory-size-constraints, it is often not advisable to use ``drop_axis``
on an axis that is chunked. In that case, it is better not to use
``map_blocks`` but rather
``dask.array.reduction(..., axis=dropped_axes, concatenate=False)`` which
maintains a leaner memory footprint while dropping the axes.
``map_blocks`` aligns blocks by block positions without regard to shape. In the
following example we have two arrays with the same number of blocks but
with different shape and chunk sizes.
>>> x = da.arange(1000, chunks=(100,))
>>> y = da.arange(100, chunks=(10,))
The relevant attribute to match is numblocks.
>>> x.numblocks
(10,)
>>> y.numblocks
(10,)
If these match (up to broadcasting rules) then we can map arbitrary
functions across blocks
>>> def func(a, b):
... return np.array([a.max(), b.max()])
>>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
>>> _.compute()
array([ 99, 9, 199, 19, 299, 29, 399, 39, 499, 49, 599, 59, 699,
69, 799, 79, 899, 89, 999, 99])
Your block function can get information about where it is in the array by
accepting a special ``block_info`` or ``block_id`` keyword argument.
During computation, they will contain information about each of the input
and output chunks (and dask arrays) relevant to each call of ``func``.
>>> def func(block_info=None):
... pass
This will receive the following information:
>>> block_info # doctest: +SKIP
{0: {'shape': (1000,),
'num-chunks': (10,),
'chunk-location': (4,),
'array-location': [(400, 500)]},
None: {'shape': (1000,),
'num-chunks': (10,),
'chunk-location': (4,),
'array-location': [(400, 500)],
'chunk-shape': (100,),
'dtype': dtype('float64')}}
The keys to the ``block_info`` dictionary indicate which is the input and
output Dask array:
- **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
The dictionary key is ``0`` because that is the argument index corresponding
to the first input Dask array.
In cases where multiple Dask arrays have been passed as input to the function,
you can access them with the number corresponding to the input argument,
e.g. ``block_info[1]``, ``block_info[2]``, etc.
(Note that if you pass multiple Dask arrays as input to map_blocks,
the arrays must match each other by having matching numbers of chunks,
along corresponding dimensions up to broadcasting rules.)
- **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
and contains information about the output chunks.
The output chunk shape and dtype may be different from the input chunks.
For each dask array, ``block_info`` describes:
- ``shape``: the shape of the full Dask array,
- ``num-chunks``: the number of chunks of the full array in each dimension,
- ``chunk-location``: the chunk location (for example the fourth chunk over
in the first dimension), and
- ``array-location``: the array location within the full Dask array
(for example the slice corresponding to ``40:50``).
In addition to these, there are two extra parameters described by
``block_info`` for the output array (in ``block_info[None]``):
- ``chunk-shape``: the output chunk shape, and
- ``dtype``: the output dtype.
These features can be combined to synthesize an array from scratch, for
example:
>>> def func(block_info=None):
... loc = block_info[None]['array-location'][0]
... return np.arange(loc[0], loc[1])
>>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>
>>> _.compute()
array([0, 1, 2, 3, 4, 5, 6, 7])
``block_id`` is similar to ``block_info`` but contains only the ``chunk_location``:
>>> def func(block_id=None):
... pass
This will receive the following information:
>>> block_id # doctest: +SKIP
(4, 3)
You may specify the key name prefix of the resulting task in the graph with
the optional ``token`` keyword argument.
>>> x.map_blocks(lambda x: x + 1, name='increment')
dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>
For functions that may not handle 0-d arrays, it's also possible to specify
``meta`` with an empty array matching the type of the expected result. In
the example below, ``func`` will result in an ``IndexError`` when computing
``meta``:
>>> rng = da.random.default_rng()
>>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
a ``dtype``:
>>> import cupy # doctest: +SKIP
>>> rng = da.random.default_rng(cupy.random.default_rng()) # doctest: +SKIP
>>> dt = np.float32
>>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt)) # doctest: +SKIP
dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
"""
if drop_axis is None:
drop_axis = []
if not callable(func):
msg = (
"First argument must be callable function, not %s\n"
"Usage: da.map_blocks(function, x)\n"
" or: da.map_blocks(function, x, y, z)"
)
raise TypeError(msg % type(func).__name__)
if token:
warnings.warn(
"The `token=` keyword to `map_blocks` has been moved to `name=`. "
"Please use `name=` instead as the `token=` keyword will be removed "
"in a future release.",
category=FutureWarning,
)
name = token
name = f"{name or funcname(func)}-{tokenize(func, dtype, chunks, drop_axis, new_axis, *args, **kwargs)}"
new_axes = {}
if isinstance(drop_axis, Number):
drop_axis = [drop_axis]
if isinstance(new_axis, Number):
new_axis = [new_axis] # TODO: handle new_axis
arrs = [a for a in args if isinsta…tack
"""
from dask.array import wrap
seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]
if not seq:
raise ValueError("Need array(s) to concatenate")
if axis is None:
seq = [a.flatten() for a in seq]
axis = 0
seq_metas = [meta_from_array(s) for s in seq]
_concatenate = concatenate_lookup.dispatch(
type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0)))
)
meta = _concatenate(seq_metas, axis=axis)
# Promote types to match meta
seq = [a.astype(meta.dtype) for a in seq]
# Find output array shape
ndim = len(seq[0].shape)
shape = tuple(
sum(a.shape[i] for a in seq) if i == axis else seq[0].shape[i]
for i in range(ndim)
)
# Drop empty arrays
seq2 = [a for a in seq if a.size]
if not seq2:
seq2 = seq
if axis < 0:
axis = ndim + axis
if axis >= ndim:
msg = (
"Axis must be less than than number of dimensions"
"\nData has %d dimensions, but got axis=%d"
)
raise ValueError(msg % (ndim, axis))
n = len(seq2)
if n == 0:
try:
return wrap.empty_like(meta, shape=shape, chunks=shape, dtype=meta.dtype)
except TypeError:
return wrap.empty(shape, chunks=shape, dtype=meta.dtype)
elif n == 1:
return seq2[0]
if not allow_unknown_chunksizes and not all(
i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2)
for i in range(ndim)
):
if any(map(np.isnan, seq2[0].shape)):
raise ValueError(
"Tried to concatenate arrays with unknown"
" shape %s.\n\nTwo solutions:\n"
" 1. Force concatenation pass"
" allow_unknown_chunksizes=True.\n"
" 2. Compute shapes with "
"[x.compute_chunk_sizes() for x in seq]" % str(seq2[0].shape)
)
raise ValueError("Shapes do not align: %s", [x.shape for x in seq2])
inds = [list(range(ndim)) for i in range(n)]
for i, ind in enumerate(inds):
ind[axis] = -(i + 1)
uc_args = list(concat(zip(seq2, inds)))
_, seq2 = unify_chunks(*uc_args, warn=False)
bds = [a.chunks for a in seq2]
chunks = (
seq2[0].chunks[:axis]
+ (sum((bd[axis] for bd in bds), ()),)
+ seq2[0].chunks[axis + 1 :]
)
cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in seq2]))
names = [a.name for a in seq2]
name = "concatenate-" + tokenize(names, axis)
keys = list(product([name], *[range(len(bd)) for bd in chunks]))
values = [
(names[bisect(cum_dims, key[axis + 1]) - 1],)
+ key[1 : axis + 1]
+ (key[axis + 1] - cum_dims[bisect(cum_dims, key[axis + 1]) - 1],)
+ key[axis + 2 :]
for key in keys
]
dsk = dict(zip(keys, values))
graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq2)
return Array(graph, name, chunks, meta=meta)
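# Editorial usage sketch for the concatenate implementation above: the chunk
# structure along the concatenated axis is the concatenation of the inputs'
# chunk structures.
def _example_concatenate():
    import dask.array as da

    x = da.ones((4, 6), chunks=(2, 3))
    y = da.zeros((2, 6), chunks=(2, 3))
    z = da.concatenate([x, y], axis=0)
    assert z.shape == (6, 6)
    assert z.chunks == ((2, 2, 2), (3, 3))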
def load_store_chunk(
x: Any,
out: Any,
index: slice,
lock: Any,
return_stored: bool,
load_stored: bool,
):
"""
A function inserted in a Dask graph for storing a chunk.
Parameters
----------
x: array-like
An array (potentially a NumPy one)
out: array-like
Where to store results.
index: slice-like
Where to store result from ``x`` in ``out``.
lock: Lock-like or False
Lock to use before writing to ``out``.
return_stored: bool
Whether to return ``out``.
load_stored: bool
Whether to return the array stored in ``out``.
Ignored if ``return_stored`` is not ``True``.
Returns
-------
If return_stored=True and load_stored=False
out
If return_stored=True and load_stored=True
out[index]
If return_stored=False
None
Examples
--------
>>> a = np.ones((5, 6))
>>> b = np.empty(a.shape)
>>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
"""
if lock:
lock.acquire()
try:
if x is not None and x.size != 0:
if is_arraylike(x):
out[index] = x
else:
out[index] = np.asanyarray(x)
if return_stored and load_stored:
return out[index]
elif return_stored and not load_stored:
return out
else:
return None
finally:
if lock:
lock.release()
def store_chunk(
x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
):
return load_store_chunk(x, out, index, lock, return_stored, False)
A = TypeVar("A", bound=ArrayLike)
def load_chunk(out: A, index: slice, lock: Any) -> A:
return load_store_chunk(None, out, index, lock, True, True)
def insert_to_ooc(
keys: list,
chunks: tuple[tuple[int, ...], ...],
out: ArrayLike,
name: str,
*,
lock: Lock | bool = True,
region: tuple[slice, ...] | slice | None = None,
return_stored: bool = False,
load_stored: bool = False,
) -> dict:
"""
Creates a Dask graph for storing chunks from ``arr`` in ``out``.
Parameters
----------
keys: list
Dask keys of the input array
chunks: tuple
Dask chunks of the input array
out: array-like
Where to store results to
name: str
First element of dask keys
lock: Lock-like or bool, optional
Whether to lock or with what (default is ``True``,
which means a :class:`threading.Lock` instance).
region: slice-like, optional
Where in ``out`` to store ``arr``'s results
(default is ``None``, meaning all of ``out``).
return_stored: bool, optional
Whether to return ``out``
(default is ``False``, meaning ``None`` is returned).
load_stored: bool, optional
Whether to handle loading from ``out`` at the same time.
Ignored if ``return_stored`` is not ``True``.
(default is ``False``, meaning defer to ``return_stored``).
Returns
-------
dask graph of store operation
Examples
--------
>>> import dask.array as da
>>> d = da.ones((5, 6), chunks=(2, 3))
>>> a = np.empty(d.shape)
>>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123") # doctest: +SKIP
"""
if lock is True:
lock = Lock()
slices = slices_from_chunks(chunks)
if region:
slices = [fuse_slice(region, slc) for slc in slices]
if return_stored and load_stored:
func = load_store_chunk
args = (load_stored,)
else:
func = store_chunk # type: ignore
args = () # type: ignore
dsk = {
(name,) + t[1:]: (func, t, out, slc, lock, return_stored) + args
for t, slc in zip(core.flatten(keys), slices)
}
return dsk
def retrieve_from_ooc(
keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
) -> dict[tuple, Any]:
"""
Creates a Dask graph for loading stored ``keys`` from ``dsk``.
Parameters
----------
keys: Collection
A sequence containing Dask graph keys to load
dsk_pre: Mapping
A Dask graph corresponding to a Dask Array before computation
dsk_post: Mapping
A Dask graph corresponding to a Dask Array after computation
Examples
--------
>>> import dask.array as da
>>> d = da.ones((5, 6), chunks=(2, 3))
>>> a = np.empty(d.shape)
>>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
>>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()}) # doctest: +SKIP
"""
load_dsk = {
("load-" + k[0],) + k[1:]: (load_chunk, dsk_post[k]) + dsk_pre[k][3:-1] # type: ignore
for k in keys
}
return load_dsk
def _as_dtype(a, dtype):
if dtype is None:
return a
else:
return a.astype(dtype)
def asarray(
a, allow_unknown_chunksizes=False, dtype=None, order=None, *, like=None, **kwargs
):
"""Convert the input to a dask array.
Parameters
----------
a : array-like
Input data, in any form that can be converted to a dask array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples of
lists and ndarrays.
allow_unknown_chunksizes: bool
Allow unknown chunksizes, such as those that come from converting from
dask dataframes. Dask.array is unable to verify that chunks line up. If
data comes from differently aligned sources then this can cause
unexpected results.
dtype : data-type, optional
By default, the data-type is inferred from the input data.
order : {'C', 'F', 'A', 'K'}, optional
Memory layout. 'A' and 'K' depend on the order of input array ``a``.
'C' means row-major (C-style), 'F' column-major (Fortran-style) memory
representation. 'A' (any) means 'F' if ``a`` is Fortran contiguous, 'C'
otherwise. 'K' (keep) preserves the input order. Defaults to 'C'.
like: array-like
Reference object to allow the creation of Dask arrays with chunks
that are not NumPy arrays. If an array-like passed in as ``like``
supports the ``__array_function__`` protocol, the chunk type of the
resulting array will be defined by it. In this case, it ensures the
creation of a Dask array compatible with that passed in via this
argument. If ``like`` is a Dask array, the chunk type of the
resulting array will be defined by the chunk type of ``like``.
Requires NumPy 1.20.0 or higher.
Returns
-------
out : dask array
Dask array interpretation of a.
Examples
--------
>>> import dask.array as da
>>> import numpy as np
>>> x = np.arange(3)
>>> da.asarray(x)
dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
>>> y = [[1, 2, 3], [4, 5, 6]]
>>> da.asarray(y)
dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>
.. warning::
`order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
or is a list or tuple of `Array`'s.
"""
if like is None:
if isinstance(a, Array):
return _as_dtype(a, dtype)
elif hasattr(a, "to_dask_array"):
return _as_dtype(a.to_dask_array(), dtype)
elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
return _as_dtype(asarray(a.data, order=order), dtype)
elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
return _as_dtype(
stack(a, allow_unknown_chunksizes=allow_unknown_chunksizes), dtype
)
elif not isinstance(getattr(a, "shape", None), Iterable):
a = np.asarray(a, dtype=dtype, order=order)
else:
like_meta = meta_from_array(like)
if isinstance(a, Array):
return a.map_blocks(np.asarray, like=like_meta, dtype=dtype, order=order)
else:
a = np.asarray(a, like=like_meta, dtype=dtype, order=order)
return from_array(a, getitem=getter_inline, **kwargs)
def asanyarray(a, dtype=None, order=None, *, like=None, inline_array=False):
"""Convert the input to a dask array.
Subclasses of ``np.ndarray`` will be passed through as chunks unchanged.
Parameters
----------
a : array-like
Input data, in any form that can be converted to a dask array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples of
lists and ndarrays.
dtype : data-type, optional
By default, the data-type is inferred from the input data.
order : {'C', 'F', 'A', 'K'}, optional
Memory layout. 'A' and 'K' depend on the order of input array ``a``.
'C' means row-major (C-style), 'F' column-major (Fortran-style) memory
representation. 'A' (any) means 'F' if ``a`` is Fortran contiguous, 'C'
otherwise. 'K' (keep) preserves the input order. Defaults to 'C'.
like: array-like
Reference object to allow the creation of Dask arrays with chunks
that are not NumPy arrays. If an array-like passed in as ``like``
supports the ``__array_function__`` protocol, the chunk type of the
resulting array will be defined by it. In this case, it ensures the
creation of a Dask array compatible with that passed in via this
argument. If ``like`` is a Dask array, the chunk type of the
resulting array will be defined by the chunk type of ``like``.
Requires NumPy 1.20.0 or higher.
inline_array:
Whether to inline the array in the resulting dask graph. For more information,
see the documentation for ``dask.array.from_array()``.
Returns
-------
out : dask array
Dask array interpretation of a.
Examples
--------
>>> import dask.array as da
>>> import numpy as np
>>> x = np.arange(3)
>>> da.asanyarray(x)
dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
>>> y = [[1, 2, 3], [4, 5, 6]]
>>> da.asanyarray(y)
dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>
.. warning::
`order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
or is a list or tuple of `Array`'s.
"""
if like is None:
if isinstance(a, Array):
return _as_dtype(a, dtype)
elif hasattr(a, "to_dask_array"):
return _as_dtype(a.to_dask_array(), dtype)
elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
return _as_dtype(asarray(a.data, order=order), dtype)
elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
return _as_dtype(stack(a), dtype)
elif not isinstance(getattr(a, "shape", None), Iterable):
a = np.asanyarray(a, dtype=dtype, order=order)
else:
like_meta = meta_from_array(like)
if isinstance(a, Array):
return a.map_blocks(np.asanyarray, like=like_meta, dtype=dtype, order=order)
else:
a = np.asanyarray(a, like=like_meta, dtype=dtype, order=order)
return from_array(
a,
chunks=a.shape,
getitem=getter_inline,
asarray=False,
inline_array=inline_array,
)
def is_scalar_for_elemwise(arg):
"""
>>> is_scalar_for_elemwise(42)
True
>>> is_scalar_for_elemwise('foo')
True
>>> is_scalar_for_elemwise(True)
True
>>> is_scalar_for_elemwise(np.array(42))
True
>>> is_scalar_for_elemwise([1, 2, 3])
True
>>> is_scalar_for_elemwise(np.array([1, 2, 3]))
False
>>> is_scalar_for_elemwise(from_array(np.array(0), chunks=()))
False
>>> is_scalar_for_elemwise(np.dtype('i4'))
True
"""
# the second half of shape_condition is essentially just to ensure that
# dask series / frame are treated as scalars in elemwise.
maybe_shape = getattr(arg, "shape", None)
shape_condition = not isinstance(maybe_shape, Iterable) or any(
is_dask_collection(x) for x in maybe_shape
)
return (
np.isscalar(arg)
or shape_condition
or isinstance(arg, np.dtype)
or (isinstance(arg, np.ndarray) and arg.ndim == 0)
)
def broadcast_shapes(*shapes):
"""
Determines output shape from broadcasting arrays.
Parameters
----------
shapes : tuples
The shapes of the arguments.
Returns
-------
output_shape : tuple
Raises
------
ValueError
If the input shapes cannot be successfully broadcast together.
"""
if len(shapes) == 1:
return shapes[0]
out = []
for sizes in zip_longest(*map(reversed, shapes), fillvalue=-1):
if np.isnan(sizes).any():
dim = np.nan
else:
dim = 0 if 0 in sizes else np.max(sizes).item()
if any(i not in [-1, 0, 1, dim] and not np.isnan(i) for i in sizes):
raise ValueError(
"operands could not be broadcast together with "
"shapes {}".format(" ".join(map(str, shapes)))
)
out.append(dim)
return tuple(reversed(out))
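# Editorial usage sketch: NumPy-style broadcasting of shape tuples; unknown
# (NaN) dimensions propagate rather than raising.
def _example_broadcast_shapes():
    assert broadcast_shapes((2, 3), (3,)) == (2, 3)
    assert broadcast_shapes((5, 1, 4), (3, 1)) == (5, 3, 4)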
def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
"""Apply an elementwise ufunc-like function blockwise across arguments.
Like numpy ufuncs, broadcasting rules are respected.
Parameters
----------
op : callable
The function to apply. Should be numpy ufunc-like in the parameters
that it accepts.
*args : Any
Arguments to pass to `op`. Non-dask array-like objects are first
converted to dask arrays, then all arrays are broadcast together before
applying the function blockwise across all arguments. Any scalar
arguments are passed as-is following normal numpy ufunc behavior.
out : dask array, optional
If out is a dask.array then this overwrites the contents of that array
with the result.
where : array_like, optional
An optional boolean mask marking locations where the ufunc should be
applied. Can be a scalar, dask array, or any other array-like object.
Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add``
for more information.
dtype : dtype, optional
If provided, overrides the output array dtype.
name : str, optional
A unique key name to use when building the backing dask graph. If not
provided, one will be automatically generated based on the input
arguments.
Examples
--------
>>> elemwise(add, x, y) # doctest: +SKIP
>>> elemwise(sin, x) # doctest: +SKIP
>>> elemwise(sin, x, out=dask_array) # doctest: +SKIP
See Also
--------
blockwise
"""
if kwargs:
raise TypeError(
f"{op.__name__} does not take the following keyword arguments "
f"{sorted(kwargs)}"
)
out = _elemwise_normalize_out(out)
where = _elemwise_normalize_where(where)
args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args]
shapes = []
for arg in args:
shape = getattr(arg, "shape", ())
if any(is_dask_collection(x) for x in shape):
# Want to exclude Delayed shapes and dd.Scalar
shape = ()
shapes.append(shape)
if isinstance(where, Array):
shapes.append(where.shape)
if isinstance(out, Array):
shapes.append(out.shape)
shapes = [s if isinstance(s, Iterable) else () for s in shapes]
out_ndim = len(
broadcast_shapes(*shapes)
) # Raises ValueError if dimensions mismatch
expr_inds = tuple(range(out_ndim))[::-1]
if dtype is not None:
need_enforce_dtype = True
else:
# We follow NumPy's rules for dtype promotion, which special cases
# scalars and 0d ndarrays (which it considers equivalent) by using
# their values to compute the result dtype:
# https://github.com/numpy/numpy/issues/6240
# We don't inspect the values of 0d dask arrays, because these could
# hold potentially very expensive calculations. Instead, we treat
# them just like other arrays, and if necessary cast the result of op
# to match.
vals = [
(
np.empty((1,) * max(1, a.ndim), dtype=a.dtype)
if not is_scalar_for_elemwise(a)
else a
)
for a in args
]
try:
dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False)
except Exception:
return NotImplemented
need_enforce_dtype = any(
not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args
)
if not name:
name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}"
blockwise_kwargs = dict(dtype=dtype, name=name, token=funcname(op).strip("_"))
if where is not True:
blockwise_kwargs["elemwise_where_function"] = op
op = _elemwise_handle_where
args.extend([where, out])
if need_enforce_dtype:
blockwise_kwargs["enforce_dtype"] = dtype
blockwise_kwargs["enforce_dtype_function"] = op
op = _enforce_dtype
result = blockwise(
op,
expr_inds,
*concat(
(a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None)
for a in args
),
**blockwise_kwargs,
)
return handle_out(out, result)
def _elemwise_normalize_where(where):
if where is True:
return True
elif where is False or where is None:
return False
return asarray(where)
def _elemwise_handle_where(*args, **kwargs):
function = kwargs.pop("elemwise_where_function")
*args, where, out = args
if hasattr(out, "copy"):
out = out.copy()
return function(*args, where=where, out=out, **kwargs)
def _elemwise_normalize_out(out):
if isinstance(out, tuple):
if len(out) == 1:
out = out[0]
elif len(out) > 1:
raise NotImplementedError("The out parameter is not fully supported")
else:
out = None
if not (out is None or isinstance(out, Array)):
raise NotImplementedError(
f"The out parameter is not fully supported."
f" Received type {type(out).__name__}, expected Dask Array"
)
return out
def handle_out(out, result):
"""Handle out parameters
If out is a dask.array then this overwrites the contents of that array with
the result
"""
out = _elemwise_normalize_out(out)
if isinstance(out, Array):
if out.shape != result.shape:
raise ValueError(
"Mismatched shapes between result and out parameter. "
"out=%s, result=%s" % (str(out.shape), str(result.shape))
)
out._chunks = result.chunks
out.dask = result.dask
out._meta = result._meta
out._name = result.name
return out
else:
return result
def _enforce_dtype(*args, **kwargs):
"""Calls a function and converts its result to the given dtype.
The parameters have deliberately been given unwieldy names to avoid
clashes with keyword arguments consumed by blockwise
A dtype of `object` is treated as a special case and not enforced,
because it is used as a dummy value in some places when the result will
not be a block in an Array.
Parameters
----------
enforce_dtype : dtype
Result dtype
enforce_dtype_function : callable
The wrapped function, which will be passed the remaining arguments
"""
dtype = kwargs.pop("enforce_dtype")
function = kwargs.pop("enforce_dtype_function")
result = function(*args, **kwargs)
if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
if not np.can_cast(result, dtype, casting="same_kind"):
raise ValueError(
"Inferred dtype from function %r was %r "
"but got %r, which can't be cast using "
"casting='same_kind'"
% (funcname(function), str(dtype), str(result.dtype))
)
if np.isscalar(result):
# scalar astype method doesn't take the keyword arguments, so
# have to convert via 0-dimensional array and back.
result = result.astype(dtype)
else:
try:
result = result.astype(dtype, copy=False)
except TypeError:
# Missing copy kwarg
result = result.astype(dtype)
return result
def broadcast_to(x, shape, chunks=None, meta=None):
"""Broadcast an array to a new shape.
Parameters
----------
x : array_like
The array to broadcast.
shape : tuple
The shape of the desired array.
chunks : tuple, optional
If provided, then the result will use these chunks instead of the same
chunks as the source array. Setting chunks explicitly as part of
broadcast_to is more efficient than rechunking afterwards. Chunks are
only allowed to differ from the original shape along dimensions that
are new in the result or have size 1 in the input array.
meta : empty ndarray
empty ndarray created with same NumPy backend, ndim and dtype as the
Dask Array being created (overrides dtype)
Returns
-------
broadcast : dask array
See Also
--------
:func:`numpy.broadcast_to`
"""
x = asarray(x)
shape = tuple(shape)
if meta is None:
meta = meta_from_array(x)
if x.shape == shape and (chunks is None or chunks == x.chunks):
return x
ndim_new = len(shape) - x.ndim
if ndim_new < 0 or any(
new != old for new, old in zip(shape[ndim_new:], x.shape) if old != 1
):
raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")
if chunks is None:
chunks = tuple((s,) for s in shape[:ndim_new]) + tuple(
bd if old > 1 else (new,)
for bd, old, new in zip(x.chunks, x.shape, shape[ndim_new:])
)
else:
chunks = normalize_chunks(
chunks, shape, dtype=x.dtype, previous_chunks=x.chunks
)
for old_bd, new_bd in zip(x.chunks, chunks[ndim_new:]):
if old_bd != new_bd and old_bd != (1,):
raise ValueError(
"cannot broadcast chunks %s to chunks %s: "
"new chunks must either be along a new "
"dimension or a dimension of size 1" % (x.chunks, chunks)
)
name = "broadcast_to-" + tokenize(x, shape, chunks)
dsk = {}
enumerated_chunks = product(*(enumerate(bds) for bds in chunks))
for new_index, chunk_shape in (zip(*ec) for ec in enumerated_chunks):
old_index = tuple(
0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:])
)
old_key = (x.name,) + old_index
new_key = (name,) + new_index
dsk[new_key] = (np.broadcast_to, old_key, quote(chunk_shape))
graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
return Array(graph, name, chunks, dtype=x.dtype, meta=meta)
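# Editorial usage sketch: broadcasting a dask array to a larger shape. A
# broadcast dimension defaults to a single chunk spanning the new extent.
def _example_broadcast_to():
    import dask.array as da

    x = da.ones((1, 4), chunks=(1, 2))
    y = broadcast_to(x, (3, 4))
    assert y.shape == (3, 4)
    assert y.chunks == ((3,), (2, 2))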
@derived_from(np)
def broadcast_arrays(*args, subok=False):
subok = bool(subok)
to_array = asanyarray if subok else asarray
args = tuple(to_array(e) for e in args)
# Unify uneven chunking
inds = [list(reversed(range(x.ndim))) for x in args]
uc_args = concat(zip(args, inds))
_, args = unify_chunks(*uc_args, warn=False)
shape = broadcast_shapes(*(e.shape for e in args))
chunks = broadcast_chunks(*(e.chunks for e in args))
if NUMPY_GE_200:
result = tuple(broadcast_to(e, shape=shape, chunks=chunks) for e in args)
else:
result = [broadcast_to(e, shape=shape, chunks=chunks) for e in args]
return result
def offset_func(func, offset, *args):
"""Offsets inputs by offset
>>> double = lambda x: x * 2
>>> f = offset_func(double, (10,))
>>> f(1)
22
>>> f(300)
620
"""
def _offset(*args):
args2 = list(map(add, args, offset))
return func(*args2)
with contextlib.suppress(Exception):
_offset.__name__ = "offset_" + func.__name__
return _offset
def chunks_from_arrays(arrays):
"""Chunks tuple from nested list of arrays
>>> x = np.array([1, 2])
>>> chunks_from_arrays([x, x])
((2, 2),)
>>> x = np.array([[1, 2]])
>>> chunks_from_arrays([[x], [x]])
((1, 1), (2,))
>>> x = np.array([[1, 2]])
>>> chunks_from_arrays([[x, x]])
((1,), (2, 2))
>>> chunks_from_arrays([1, 1])
((1, 1),)
"""
if not arrays:
return ()
result = []
dim = 0
def shape(x):
try:
return x.shape if x.shape else (1,)
except AttributeError:
return (1,)
while isinstance(arrays, (list, tuple)):
> result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
E IndexError: tuple index out of range
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5237: IndexError
Check warning on line 0 in distributed.tests.test_stress
github-actions / Unit Test Results
8 out of 9 runs failed: test_stress_creation_and_deletion (distributed.tests.test_stress)
artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 8s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 7s]
Raw output
AssertionError: assert 'round-c642c375b3c4b5b78c9119e82165ed0a' == 8000884.93
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33919', workers: 0, cores: 0, tasks: 0>
@pytest.mark.slow
@gen_cluster(
nthreads=[],
client=True,
scheduler_kwargs={"allowed_failures": 100_000},
)
async def test_stress_creation_and_deletion(c, s):
# Assertions are handled by the validate mechanism in the scheduler
pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
rng = da.random.RandomState(0)
x = rng.random(size=(2000, 2000), chunks=(100, 100))
y = ((x + 1).T + (x * 2) - x.mean(axis=1)).sum().round(2)
z = c.persist(y)
async def create_and_destroy_worker(delay):
start = time()
while time() < start + 5:
async with Worker(s.address, nthreads=2) as n:
await asyncio.sleep(delay)
await asyncio.gather(*(create_and_destroy_worker(0.1 * i) for i in range(20)))
async with Worker(s.address, nthreads=2):
> assert await c.compute(z) == 8000884.93
E AssertionError: assert 'round-c642c375b3c4b5b78c9119e82165ed0a' == 8000884.93
distributed/tests/test_stress.py:125: AssertionError
Check warning on line 0 in distributed.comm.tests.test_comms
github-actions / Unit Test Results
1 out of 7 runs failed: test_tls_comm_closed_implicit[tornado] (distributed.comm.tests.test_comms)
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
Raw output
ssl.SSLError: [SYS] unknown error (_ssl.c:2580)
tcp = <module 'distributed.comm.tcp' from '/home/runner/work/distributed/distributed/distributed/comm/tcp.py'>
@gen_test()
async def test_tls_comm_closed_implicit(tcp):
> await check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs)
distributed/comm/tests/test_comms.py:777:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/comm/tests/test_comms.py:763: in check_comm_closed_implicit
await comm.read()
distributed/comm/tcp.py:225: in read
frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:422: in read_bytes
self._try_inline_read()
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:836: in _try_inline_read
pos = self._read_to_buffer_loop()
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:750: in _read_to_buffer_loop
if self._read_to_buffer() == 0:
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:861: in _read_to_buffer
bytes_read = self.read_from_fd(buf)
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:1552: in read_from_fd
return self.socket.recv_into(buf, len(buf))
../../../miniconda3/envs/dask-distributed/lib/python3.11/ssl.py:1314: in recv_into
return self.read(nbytes, buffer)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <ssl.SSLSocket [closed] fd=-1, family=2, type=1, proto=0>, len = 65536
buffer = bytearray(b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x...0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')
def read(self, len=1024, buffer=None):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
self._checkClosed()
if self._sslobj is None:
raise ValueError("Read on closed or unwrapped SSL socket.")
try:
if buffer is not None:
> return self._sslobj.read(len, buffer)
E ssl.SSLError: [SYS] unknown error (_ssl.c:2580)
../../../miniconda3/envs/dask-distributed/lib/python3.11/ssl.py:1166: SSLError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_configuration[p2p-tasks] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 39810932bc72623ea866af170a896ed5 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
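Note: the loop above retries only those workers whose broadcast failed with OSError and aborts after three consecutive rounds without progress. A minimal, self-contained sketch of that retry pattern, with an illustrative broadcast callable (not the scheduler's actual API):

    import asyncio

    async def broadcast_until_acked(broadcast, workers, max_stalls=3):
        # Retry OSError failures; count consecutive rounds with no progress
        # and give up after max_stalls stalled rounds, as barrier() does above.
        stalls = 0
        while workers:
            res = await broadcast(workers)  # {worker: None | OSError | other}
            before = len(workers)
            for r in res.values():
                if r is not None and not isinstance(r, OSError):
                    raise RuntimeError(f"Unexpected error during broadcast: {r!r}")
            workers = [w for w, r in res.items() if isinstance(r, OSError)]
            if workers:
                await asyncio.sleep(0.1)
                if len(workers) == before:
                    stalls += 1
                    if stalls >= max_stalls:
                        raise RuntimeError("Broadcast not making progress")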
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '39810932bc72623ea866af170a896ed5'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
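The KeyError above comes from the plain dict lookup in _get: the shuffle id is not yet (or no longer) in active_shuffles. get() translates that into a P2PConsistencyError, while get_or_create() (see the frames further down) treats it as a signal to create a new run. A tiny sketch of the translation, with hypothetical names:

    active_shuffles: dict[str, object] = {}

    def get(shuffle_id: str) -> object:
        # Mirror of ShuffleSchedulerPlugin.get()/_get(): a failed lookup is
        # surfaced as a consistency error rather than a bare KeyError.
        try:
            return active_shuffles[shuffle_id]
        except KeyError as e:
            raise RuntimeError(f"No active shuffle with id={shuffle_id!r} found") from e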
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
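The callback above normalizes every metric label into a ('p2p', <span id>, <where>, *label, unit) tuple, stripping a leading 'p2p-' prefix first. The same rewriting in isolation (span id, where, and unit are placeholder values):

    def metric_name(label, span_id="span-0", where="foreground", unit="seconds"):
        # Same normalization as the callback in _capture_metrics above.
        if not isinstance(label, tuple):
            label = (label,)
        if isinstance(label[0], str) and label[0].startswith("p2p-"):
            label = (label[0][len("p2p-"):], *label[1:])
        return ("p2p", span_id, where, *label, unit)

    assert metric_name("p2p-shard-partition-cpu") == (
        "p2p", "span-0", "foreground", "shard-partition-cpu", "seconds"
    )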
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
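send() batches small shards into a single pickled blob, using the 64 KiB mean-shard-size threshold benchmarked in the PR linked above. _mean_shard_size itself does not appear in this output; the helper below is a hypothetical stand-in, only to illustrate the branch:

    import pickle
    from typing import Any

    def mean_shard_size(shards: list[tuple[Any, Any]]) -> int:
        # Hypothetical stand-in for _mean_shard_size: estimate the average
        # payload size from the pickled byte length of each shard.
        if not shards:
            return 0
        return sum(len(pickle.dumps(s)) for s in shards) // len(shards)

    shards = [(i, b"x" * 100) for i in range(10)]
    # Small shards travel as one opaque blob; large shards travel as-is.
    payload = pickle.dumps(shards) if mean_shard_size(shards) < 65536 else shards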
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…
    await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
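_retrieve_spec expects the barrier task's run_spec to be a P2PBarrierTask; per the FIXME in _create (shown further down), this only holds while the barrier task keeps its well-known name. The key scheme itself is a simple prefix, as the helpers in the _core.py listing below show; a mangled key no longer maps back to a shuffle id:

    _BARRIER_PREFIX = "shuffle-barrier-"

    def barrier_key(shuffle_id: str) -> str:
        return _BARRIER_PREFIX + shuffle_id

    def id_from_key(key) -> str | None:
        # Non-string or renamed keys are not recognized as barriers.
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return key[len(_BARRIER_PREFIX):]

    assert id_from_key(barrier_key("39810932bc72623ea866af170a896ed5")) is not None
    assert id_from_key(("mangled", "tuple-key")) is None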
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:42801', workers: 0, cores: 0, tasks: 0>
config_value = 'tasks', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:39467', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:35233', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.54100301, 0.39676024, 0.80875911, 0.57210751, 0.49079247,
0.98476684, 0.55330331, 0.02055366, 0.4693..., 0.57304034, 0.71045835, 0.80842066, 0.66522237,
0.10320915, 0.77049627, 0.21466926, 0.95670238, 0.30806249]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
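The test selects the rechunk algorithm from the method= keyword, falling back to the array.rechunk.method config value. A minimal sketch of the same selection outside the test harness (graph construction only, no cluster; assumes the config value is honored by dask.array's rechunk as exercised by the test):

    import dask
    import dask.array as da
    import numpy as np

    a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
    x = da.from_array(a, chunks=(10, 1))
    new = ((1,) * 10, (10,))

    # The keyword wins; otherwise the config decides; otherwise a heuristic.
    with dask.config.set({"array.rechunk.method": "p2p"}):
        x2 = x.rechunk(new)
    assert x2.chunks == new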
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
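ShuffleRunSpec draws run_id from a shared module-level counter via a dataclass default_factory, so each new run spec gets a strictly increasing id without any caller passing one in. The pattern in isolation:

    import itertools
    from dataclasses import dataclass, field
    from functools import partial

    @dataclass(frozen=True)
    class Run:
        # Each instance takes the next value from one shared counter,
        # like ShuffleRunSpec.run_id above.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))

    assert (Run().run_id, Run().run_id) == (1, 2)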
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 39810932bc72623ea866af170a896ed5 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
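handle_transfer_errors (listed above) lets Reschedule, consistency, and out-of-disk errors pass through but re-raises anything else as a RuntimeError tagged with the shuffle id; that wrapper produces the exact error reported by this job, with the original KeyError/AssertionError chain preserved. A self-contained sketch of the same wrapping:

    import contextlib
    from collections.abc import Iterator

    @contextlib.contextmanager
    def wrap_transfer_errors(shuffle_id: str) -> Iterator[None]:
        # Same shape as handle_transfer_errors, minus the P2P-specific
        # pass-through exception types.
        try:
            yield
        except Exception as e:
            raise RuntimeError(f"P2P {shuffle_id} failed during transfer phase") from e

    try:
        with wrap_transfer_errors("39810932bc72623ea866af170a896ed5"):
            raise KeyError("missing shuffle state")  # stand-in for the failure above
    except RuntimeError as err:
        assert isinstance(err.__cause__, KeyError)  # cause survives via chaining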
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_configuration[p2p-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P e07f19a3ad28b5edecde97f5c6e27492 failed during transfer phase
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
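Per the frame markers, get_or_create first tries _get (line 229) and, when that raises KeyError, falls back to _create (line 234), which calls _retrieve_spec and trips the failing assertion. A hedged reconstruction of that control flow from the traceback (not the verbatim method body):

    def get_or_create(plugin, shuffle_id, key, worker):
        # Reconstructed from the frames above: _get at line 229 raises
        # KeyError, then _create at line 234 takes over.
        try:
            return plugin._get(shuffle_id, worker)
        except KeyError:
            return plugin._create(shuffle_id, key, worker)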
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'e07f19a3ad28b5edecde97f5c6e27492'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
# Log metrics both in the "execute" and… await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33805', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:46259', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:43611', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.64399807, 0.31063259, 0.17543577, 0.82783572, 0.74638352,
0.19865671, 0.16140169, 0.6856798 , 0.0280..., 0.60435229, 0.88620013, 0.03875494, 0.32640197,
0.12674629, 0.42828326, 0.19082475, 0.71912404, 0.68470373]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
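    # Illustration (hypothetical values, not part of the module source): a
    # foreground metric labelled "p2p-shard-partition-cpu" with unit "seconds"
    # in span "abc" would be digested by the callback above roughly as
    #   ("p2p", "abc", "foreground", "shard-partition-cpu", "seconds") -> value
    # i.e. the "p2p-" prefix is stripped and (where, *label, unit) form the key.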
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
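    # Illustration (hypothetical sizes, not part of the module source): given the
    # 65536-byte mean-shard threshold in send() above,
    #   _mean_shard_size(shards) ~ 1_000     -> pickle.dumps(shards), one bytes blob
    #   _mean_shard_size(shards) ~ 1_000_000 -> shards sent as-is, buffers intact
    # so batches of tiny shards avoid per-buffer overhead on the comms.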
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
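    # Note (illustration, not part of the module source): if the task for output
    # partition i runs on the wrong worker, the scheduler is asked to restrict
    # the task to the assigned worker and Reschedule() is raised, so the task
    # re-runs where the shards for that partition actually live.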
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
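# Illustration (example id is made up; behaviour follows from the helpers above):
#   barrier_key(ShuffleId("abc123"))      == "shuffle-barrier-abc123"
#   id_from_key("shuffle-barrier-abc123") == ShuffleId("abc123")
#   id_from_key(("tuple", "key", 0))      is None  # non-str keys never match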
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
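# Illustration (not part of the module source): ShuffleRunSpec.run_id is drawn
# from a single itertools.count(1) created at class-definition time, so each new
# run spec gets a strictly increasing run_id; the scheduler plugin's
# restrict_task above compares these run_ids to reject stale or mismatched requests.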
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P e07f19a3ad28b5edecde97f5c6e27492 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
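For reference, a minimal repro sketch of the failing path (assumptions: a local in-process Client stands in for the test's gen_cluster, and the failure reproduces outside CI; this snippet is not part of the test output):

import dask
import dask.array as da
import numpy as np
from distributed import Client

client = Client(processes=False)  # hypothetical stand-in for gen_cluster
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
with dask.config.set({"array.rechunk.method": "p2p"}):
    x2 = x.rechunk(((1,) * 10, (10,)))    # same target chunks as the test
    result = client.compute(x2).result()  # raises "P2P ... failed during transfer phase" as above
assert np.all(result == a)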
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_configuration[p2p-None] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P fe1d1aff4e980bf6369e4d4adf98f81b failed during transfer phase
[duplicated ShuffleSchedulerPlugin source context elided; identical to the listing above]
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[duplicated ShuffleSchedulerPlugin source context elided; identical to the listing above]
> state = self.active_shuffles[id]
E KeyError: 'fe1d1aff4e980bf6369e4d4adf98f81b'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
[duplicated ShuffleRun source context elided; identical to the listing above]
…= await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[duplicated ShuffleSchedulerPlugin source context elided; identical to the listing above]
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:44351', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:42471', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41635', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.05373005, 0.04687012, 0.18475967, 0.63853183, 0.37113315,
0.34780472, 0.59389736, 0.66525719, 0.1175..., 0.43862109, 0.26487191, 0.08598918, 0.65790677,
0.24254732, 0.67961834, 0.27260373, 0.6418115 , 0.75985999]])
[duplicated test_rechunk_configuration source elided; identical to the listing above]
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
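# --- Sketch: how the retry settings above are normalized. dask's
# parse_timedelta converts the configured delays to seconds (bare numbers
# fall back to the "s" unit) and parse_bytes handles the message-size limit.
# The literal values below are illustrative, not the config defaults.
from dask.utils import parse_bytes, parse_timedelta

assert parse_timedelta("250ms", default="s") == 0.25
assert parse_timedelta(3, default="s") == 3
assert parse_bytes("2 MiB") == 2 * 2**20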
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
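# --- Sketch: the label normalization performed by the callback inside
# _capture_metrics above, extracted into a standalone function. `span_id`
# and `where` are stand-in arguments for the instance attributes.
from collections.abc import Hashable

def normalize(label: Hashable, unit: str, span_id: str, where: str) -> tuple:
    if not isinstance(label, tuple):
        label = (label,)
    if isinstance(label[0], str) and label[0].startswith("p2p-"):
        label = (label[0][len("p2p-") :], *label[1:])
    return ("p2p", span_id, where, *label, unit)

assert normalize("p2p-shard-partition-cpu", "seconds", "span-1", "foreground") == (
    "p2p", "span-1", "foreground", "shard-partition-cpu", "seconds"
)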
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
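# --- Sketch: the size-based decision in send() above. Shards with a mean
# size below 64 kiB travel as one pickled blob; _mean_shard_size_demo is an
# illustrative stand-in for the module's _mean_shard_size helper.
import pickle

def _mean_shard_size_demo(shards: list[tuple[int, bytes]]) -> float:
    sizes = [len(payload) for _, payload in shards]
    return sum(sizes) / max(len(sizes), 1)

demo_shards = [(0, b"x" * 100), (1, b"y" * 200)]
payload = (
    pickle.dumps(demo_shards)
    if _mean_shard_size_demo(demo_shards) < 65536
    else demo_shards
)
assert isinstance(payload, bytes)  # small shards are merged into one blob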
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
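# --- Sketch: the disk-key encoding shared by _write_to_disk and
# _read_from_disk above. An n-dimensional partition index is flattened into
# an underscore-joined string, which is trivially reversible.
ndindex = (2, 0, 1)
disk_key = "_".join(str(i) for i in ndindex)
assert disk_key == "2_0_1"
assert tuple(int(part) for part in disk_key.split("_")) == ndindex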
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
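# --- Sketch: round-tripping a shuffle id through the helpers above
# (assumes ShuffleId, barrier_key and id_from_key from this module are in
# scope; the id value is taken from the failure reported below).
sid = ShuffleId("fe1d1aff4e980bf6369e4d4adf98f81b")
assert barrier_key(sid) == "shuffle-barrier-fe1d1aff4e980bf6369e4d4adf98f81b"
assert id_from_key(barrier_key(sid)) == sid
assert id_from_key(("not", "a", "barrier")) is None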
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
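# --- Sketch: the run_id mechanics of ShuffleRunSpec above. Using
# field(default_factory=partial(next, itertools.count(1))) hands every new
# instance the next integer from one process-wide counter, so a restarted
# shuffle run always carries a strictly larger run_id.
import itertools
from dataclasses import dataclass, field
from functools import partial

@dataclass(frozen=True)
class _RunIdDemo:
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))

assert _RunIdDemo().run_id == 1
assert _RunIdDemo().run_id == 2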
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P fe1d1aff4e980bf6369e4d4adf98f81b failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
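The RuntimeError above comes from the handle_transfer_errors context manager at the end of the _core.py listing: P2P control-flow exceptions (ShuffleClosedError, P2PConsistencyError, P2POutOfDiskError) pass through or become a Reschedule, while any other failure is wrapped into the generic transfer-phase RuntimeError. A minimal, self-contained sketch of that translation, with local stand-ins for the distributed exception classes:

import contextlib

class ShuffleClosedError(Exception): ...
class Reschedule(Exception): ...

@contextlib.contextmanager
def handle_transfer_errors(id: str):
    try:
        yield
    except ShuffleClosedError:
        raise Reschedule()
    except Exception as e:
        raise RuntimeError(f"P2P {id} failed during transfer phase") from e

try:
    with handle_transfer_errors("fe1d1aff4e980bf6369e4d4adf98f81b"):
        raise AssertionError()  # stands in for the P2PBarrierTask assert
except RuntimeError as e:
    assert isinstance(e.__cause__, AssertionError)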
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_configuration[None-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P aad191a3b35f209c6fd930cfec08e9c4 failed during transfer phase
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
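# --- Sketch: the get-or-create pattern that the frames below step through.
# get_or_create first tries _get(); on KeyError it falls back to _create(),
# which is where _retrieve_spec() raised in this failure. Names here are
# illustrative stand-ins, not the plugin's exact code.
def get_or_create_demo(active: dict, shuffle_id: str, make_spec) -> object:
    try:
        return active[shuffle_id]
    except KeyError:
        active[shuffle_id] = make_spec()
        return active[shuffle_id]

demo_active: dict = {}
assert get_or_create_demo(demo_active, "s1", lambda: "spec-1") == "spec-1"
assert get_or_create_demo(demo_active, "s1", lambda: "spec-2") == "spec-1"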
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'aad191a3b35f209c6fd930cfec08e9c4'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
…= await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:32987', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = None
ws = (<Worker 'tcp://127.0.0.1:43423', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40799', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.6420339 , 0.83288408, 0.4242612 , 0.40240756, 0.42555848,
0.10922281, 0.28362424, 0.53933638, 0.4652..., 0.49628131, 0.43796726, 0.73645182, 0.75935205,
0.0463592 , 0.93290705, 0.51500255, 0.631823 , 0.02839664]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
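# Together, `barrier_key` and `id_from_key` form a reversible naming scheme
# for barrier tasks. A quick illustrative round trip:
#
#     sid = ShuffleId("aad191a3b35f209c6fd930cfec08e9c4")
#     key = barrier_key(sid)            # "shuffle-barrier-aad191a3b35f209c6fd930cfec08e9c4"
#     assert id_from_key(key) == sid    # the prefix is stripped back off
#     assert id_from_key("otherwise-named-task") is None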
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
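# Note that `run_id` is excluded from __init__ and drawn from a shared
# itertools counter, so each ShuffleRunSpec created in this process receives
# the next monotonically increasing id. A minimal illustration of the idiom
# (not library code):
#
#     from dataclasses import dataclass, field
#     from functools import partial
#     import itertools
#
#     @dataclass(frozen=True)
#     class Run:
#         run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
#
#     assert Run().run_id == 1
#     assert Run().run_id == 2  # each instance advances the shared counter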
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P aad191a3b35f209c6fd930cfec08e9c4 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
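As the handle_transfer_errors context manager above shows, any unexpected
exception raised during the transfer phase is re-raised as RuntimeError("P2P
<id> failed during transfer phase") chained from the original error, while
ShuffleClosedError is converted into a Reschedule and the P2P-specific errors
propagate unchanged. A minimal sketch of a call site, loosely modeled on the
rechunk_transfer frame that appears further down in this report (names
simplified and hypothetical, not the actual implementation):

    def transfer_task(id: ShuffleId, partition_id, data):
        with handle_transfer_errors(id):
            run = _lookup_shuffle_run(id)  # hypothetical helper; the real code goes through the worker plugin
            return run.add_partition(data, partition_id)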
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_configuration[None-None] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P ed041f685e395fded8871501929b9a5e failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'ed041f685e395fded8871501929b9a5e'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
… = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:42283', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = None
ws = (<Worker 'tcp://127.0.0.1:38627', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40285', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.83001867, 0.80178537, 0.23115057, 0.22929814, 0.05555555,
0.53942166, 0.84898209, 0.27603218, 0.0611..., 0.75310379, 0.19612405, 0.08756546, 0.0405993 ,
0.66275933, 0.59495111, 0.77384036, 0.72799921, 0.69120823]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P ed041f685e395fded8871501929b9a5e failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
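The rechunk failure above exercises how the rechunking algorithm is selected:
an explicit method= keyword takes precedence, otherwise the
"array.rechunk.method" config value applies, and if neither is set a
heuristic chooses the algorithm (p2p for this workload, hence the
rechunk-p2p keys asserted in the test). A minimal usage sketch under those
assumptions (graph construction only; p2p additionally requires a
distributed client at compute time):

    import dask
    import dask.array as da
    import numpy as np

    x = da.from_array(np.random.default_rng().uniform(0, 1, (10, 10)), chunks=(10, 1))
    new = ((1,) * 10, (10,))

    y = da.rechunk(x, chunks=new, method="p2p")        # keyword wins
    with dask.config.set({"array.rechunk.method": "tasks"}):
        z = da.rechunk(x, chunks=new)                  # falls back to config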
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_heuristic[new0-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 4eafd5a6ec56ad19493d1adddac5f652 failed during transfer phase
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '4eafd5a6ec56ad19493d1adddac5f652'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
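# Not part of the source above — a worked example of the naming scheme the
# docstring describes, under assumed values (span_id "span-1", a foreground
# metric labelled "p2p-shard-partition-cpu" in unit "seconds"):
#
#     callback("p2p-shard-partition-cpu", 0.25, "seconds")
#     # digests under ("p2p", "span-1", "foreground", "shard-partition-cpu", "seconds")
#     # i.e. the "p2p-" prefix is stripped from the label before namespacing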
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
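# Not part of the source above — a toy sketch of the size-based branch in
# send(), with made-up shard data; only pickle from the stdlib is assumed:
#
#     import pickle
#     shards = [(i, b"x" * 100) for i in range(10)]  # mean shard ~100 B < 65536
#     blob = pickle.dumps(shards)          # merged into one opaque bytes blob
#     assert pickle.loads(blob) == shards  # unpacked on the receiving side
#
# Shards above the ~64 KiB mean-size threshold skip this and are sent as-is.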
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
…
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:41493', workers: 0, cores: 0, tasks: 0>
a = array([[0.37151424, 0.96300086, 0.03572343, ..., 0.13664586, 0.79334727,
0.45763255],
[0.6307737 , 0.37...21,
0.63065886],
[0.67865086, 0.35168033, 0.19536338, ..., 0.7557135 , 0.97195412,
0.57730087]])
b = <Worker 'tcp://127.0.0.1:39939', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((1, 1, 1, 1, 1, 1, ...), (100,)), expected_algorithm = 'p2p'
@pytest.mark.parametrize(
["new", "expected_algorithm"],
[
# All-to-all rechunking defaults to P2P
(((1,) * 100, (100,)), "p2p"),
# Localized rechunking defaults to tasks
(((50, 50), (2,) * 50), "tasks"),
# Less local rechunking first defaults to tasks,
(((25, 25, 25, 25), (4,) * 25), "tasks"),
# then switches to p2p
(((10,) * 10, (10,) * 10), "p2p"),
],
)
@gen_cluster(client=True)
async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
x = da.from_array(a, chunks=(100, 1))
x2 = rechunk(x, chunks=new)
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
else:
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:239:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
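# Not part of the source above — a minimal round trip through the two helpers,
# reusing the shuffle id from the KeyError earlier in this log:
#
#     sid = ShuffleId("4eafd5a6ec56ad19493d1adddac5f652")
#     key = barrier_key(sid)  # "shuffle-barrier-4eafd5a6ec56ad19493d1adddac5f652"
#     assert id_from_key(key) == sid
#     assert id_from_key("unrelated-task-key") is None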
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
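# Not part of the source above — a worked example of create_new_run() on a
# hypothetical spec instance, with hypothetical worker addresses:
#
#     worker_for = {0: "tcp://a:1234", 1: "tcp://b:1234", 2: "tcp://a:1234"}
#     state = spec.create_new_run(worker_for=worker_for, span_id=None)
#     assert state.participating_workers == {"tcp://a:1234", "tcp://b:1234"}
#
# participating_workers is seeded from set(worker_for.values()).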
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 4eafd5a6ec56ad19493d1adddac5f652 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_heuristic[new3-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P c806e7671339aecf80206effa5e88ca8 failed during transfer phase
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'c806e7671339aecf80206effa5e88ca8'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
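# Illustrative note, not from the module source: with where="foreground",
# a metric logged under the label "p2p-shard-partition-cpu" in unit "seconds"
# has its "p2p-" prefix stripped by callback() and is digested as
# ("p2p", self.span_id, "foreground", "shard-partition-cpu", "seconds"),
# matching the naming scheme described in the docstring above.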
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
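# Illustrative note, not from the module source: _mean_shard_size is assumed
# here to average the byte size of the shards in the batch, so batches
# averaging under 64 KiB (65536 bytes) take the pickled-blob branch above and
# go out as one message instead of many small per-buffer frames.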
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
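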
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33641', workers: 0, cores: 0, tasks: 0>
a = array([[0.77218519, 0.98403101, 0.73971916, ..., 0.64295602, 0.86854558,
0.98601708],
[0.09089845, 0.03...8 ,
0.69222466],
[0.63829716, 0.61898815, 0.94981538, ..., 0.25538011, 0.30019752,
0.31170137]])
b = <Worker 'tcp://127.0.0.1:41547', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((10, 10, 10, 10, 10, 10, ...), (10, 10, 10, 10, 10, 10, ...))
expected_algorithm = 'p2p'
@pytest.mark.parametrize(
["new", "expected_algorithm"],
[
# All-to-all rechunking defaults to P2P
(((1,) * 100, (100,)), "p2p"),
# Localized rechunking defaults to tasks
(((50, 50), (2,) * 50), "tasks"),
# Less local rechunking first defaults to tasks,
(((25, 25, 25, 25), (4,) * 25), "tasks"),
# then switches to p2p
(((10,) * 10, (10,) * 10), "p2p"),
],
)
@gen_cluster(client=True)
async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
x = da.from_array(a, chunks=(100, 1))
x2 = rechunk(x, chunks=new)
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
else:
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:239:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
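# Illustrative example, not from the module source: for ShuffleId("abc123"),
# barrier_key gives "shuffle-barrier-abc123" and id_from_key maps that string
# back to ShuffleId("abc123"); id_from_key returns None for non-string keys
# and for keys lacking the "shuffle-barrier-" prefix.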
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P c806e7671339aecf80206effa5e88ca8 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
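For local triage, here is a minimal reproduction sketch of the failing parametrization. It drives the same all-to-all rechunk through a real client rather than the gen_cluster harness used by the test; the LocalCluster settings are illustrative assumptions, not taken from the test above.
import numpy as np
import dask.array as da
from distributed import Client, LocalCluster

if __name__ == "__main__":
    # Two workers, mirroring the a/b workers in the failure above (sizes assumed)
    with LocalCluster(n_workers=2, threads_per_worker=2) as cluster, Client(cluster):
        a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
        x = da.from_array(a, chunks=(100, 1))
        # ((10,) * 10, (10,) * 10) is the parametrization that failed above;
        # the heuristic is expected to route it through P2P rechunking
        x2 = x.rechunk(((10,) * 10, (10,) * 10))
        assert np.all(x2.compute() == a)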
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_cull_p2p_rechunk_independent_partitions (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
assert 57 < (228 / 4)
+ where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948f9dc5e0>\n 0. getitem-d1db8f0a6402f938465bc6034b01ed58\n)
+ where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948f9dc5e0>\n 0. getitem-d1db8f0a6402f938465bc6034b01ed58\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
+ and 228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948e19f490>\n 0. rechunk-p2p-620b9aabeabeb94b707695b0db4eb07a\n)
+ where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948e19f490>\n 0. rechunk-p2p-620b9aabeabeb94b707695b0db4eb07a\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36421', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:45917', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:36767', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[[0.32911429, 0.67956124, 0.09958113, 0.31750314, 0.54180337,
0.42813737, 0.07759511, 0.17287479, 0.19...0.91310963, 0.59520114, 0.99521652, 0.35626236,
0.01563698, 0.81999609, 0.8769825 , 0.71105292, 0.33717714]]])
x = dask.array<array, shape=(10, 10, 10), dtype=float64, chunksize=(1, 5, 1), chunktype=numpy.ndarray>
new = (5, 1, -1)
@gen_cluster(client=True)
async def test_cull_p2p_rechunk_independent_partitions(c, s, *ws):
a = np.random.default_rng().uniform(0, 1, 1000).reshape((10, 10, 10))
x = da.from_array(a, chunks=(1, 5, 1))
new = (5, 1, -1)
rechunked = rechunk(x, chunks=new, method="p2p")
(dsk,) = dask.optimize(rechunked)
culled = rechunked[:5, :2]
(dsk_culled,) = dask.optimize(culled)
# The culled graph requires only 1/4 of the input tasks
n_inputs = len(
[1 for key in dsk.dask.get_all_dependencies() if key[0].startswith("array-")]
)
n_culled_inputs = len(
[
1
for key in dsk_culled.dask.get_all_dependencies()
if key[0].startswith("array-")
]
)
assert n_culled_inputs == n_inputs / 4
# The culled graph should also have less than 1/4 the tasks
> assert len(dsk_culled.dask) < len(dsk.dask) / 4
E assert 57 < (228 / 4)
E + where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948f9dc5e0>\n 0. getitem-d1db8f0a6402f938465bc6034b01ed58\n)
E + where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948f9dc5e0>\n 0. getitem-d1db8f0a6402f938465bc6034b01ed58\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
E + and 228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948e19f490>\n 0. rechunk-p2p-620b9aabeabeb94b707695b0db4eb07a\n)
E + where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f948e19f490>\n 0. rechunk-p2p-620b9aabeabeb94b707695b0db4eb07a\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
distributed/shuffle/tests/test_rechunk.py:265: AssertionError
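Because both graphs are materialized at optimization time, the failed comparison can be reproduced without a cluster. A sketch follows, assuming Array.rechunk accepts method="p2p" like the rechunk function the test imports; note that the culled graph shrank to exactly one quarter of the full graph (57 of 228 tasks), so the strict less-than assertion failed.
import numpy as np
import dask
import dask.array as da

a = np.random.default_rng().uniform(0, 1, 1000).reshape((10, 10, 10))
x = da.from_array(a, chunks=(1, 5, 1))
rechunked = x.rechunk((5, 1, -1), method="p2p")
(dsk,) = dask.optimize(rechunked)
culled = rechunked[:5, :2]
(dsk_culled,) = dask.optimize(culled)
# The test asserts len(dsk_culled.dask) < len(dsk.dask) / 4; in the failure
# above this compared 57 < 228 / 4, i.e. 57 < 57.0, which is False
print(len(dsk_culled.dask), "of", len(dsk.dask), "tasks survive culling")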
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_expand (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P bbd47b3374e9ae0063a0808c2bcfb35f failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'bbd47b3374e9ae0063a0808c2bcfb35f'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
> yield
distributed/shuffle/_core.py:523:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/shuffle/_rechunk.py:170: in rechunk_transfer
return get_worker_plugin().add_partition(
distributed/shuffle/_worker_plugin.py:348: in add_partition
shuffle_run = self.get_or_create_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33125', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:36817', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:44325', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.20961821, 0.03813439, 0.16031038, 0.83008987, 0.47724403,
0.15157882, 0.65182507, 0.81244842, 0.2439..., 0.22958412, 0.27287624, 0.66399029, 0.51144192,
0.08086217, 0.82881384, 0.60756522, 0.08527088, 0.51785269]])
x = dask.array<array, shape=(10, 10), dtype=float64, chunksize=(5, 5), chunktype=numpy.ndarray>
y = dask.array<rechunk-p2p, shape=(10, 10), dtype=float64, chunksize=(3, 3), chunktype=numpy.ndarray>
@gen_cluster(client=True)
async def test_rechunk_expand(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_expand
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(5, 5))
y = x.rechunk(chunks=((3, 3, 3, 1), (3, 3, 3, 1)), method="p2p")
> assert np.all(await c.compute(y) == a)
distributed/shuffle/tests/test_rechunk.py:377:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P bbd47b3374e9ae0063a0808c2bcfb35f failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
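The RuntimeError above is the wrapping applied by the handle_transfer_errors context manager shown immediately above: ShuffleClosedError is converted into a Reschedule request, P2PConsistencyError and P2POutOfDiskError propagate unchanged, and any other exception (here, the AssertionError from _retrieve_spec) is re-raised as a RuntimeError tagged with the shuffle id. A minimal self-contained sketch of that pattern, using stand-in exception classes rather than the real distributed ones:

import contextlib
from collections.abc import Iterator

class Reschedule(Exception):
    """Stand-in for distributed.exceptions.Reschedule."""

class ShuffleClosedError(RuntimeError):
    """Stand-in for distributed.shuffle._exceptions.ShuffleClosedError."""

@contextlib.contextmanager
def handle_transfer_errors(id: str) -> Iterator[None]:
    try:
        yield
    except ShuffleClosedError:
        # A closed shuffle run is recoverable: ask the scheduler to reschedule.
        raise Reschedule()
    except Exception as e:
        # Everything else surfaces as the RuntimeError seen in this report.
        raise RuntimeError(f"P2P {id} failed during transfer phase") from e

try:
    with handle_transfer_errors("bbd47b3374e9ae0063a0808c2bcfb35f"):
        raise AssertionError  # e.g. the failed isinstance check in _retrieve_spec
except RuntimeError as e:
    assert isinstance(e.__cause__, AssertionError)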
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 6 runs failed: test_rechunk_expand2 (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P b07060492e82fee76856c8bc3b212713 failed during transfer phase
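The traceback below bottoms out in ShuffleSchedulerPlugin._retrieve_spec, which looks up the barrier task for the shuffle by key. That key is derived from the shuffle id by the barrier_key/id_from_key helpers in distributed/shuffle/_core.py (dumped earlier in this report); a self-contained sketch of the round trip, with plain str standing in for the ShuffleId NewType:

from __future__ import annotations

_BARRIER_PREFIX = "shuffle-barrier-"

def barrier_key(shuffle_id: str) -> str:
    return _BARRIER_PREFIX + shuffle_id

def id_from_key(key: str) -> str | None:
    if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
        return None
    return key[len(_BARRIER_PREFIX) :]

key = barrier_key("b07060492e82fee76856c8bc3b212713")
assert key == "shuffle-barrier-b07060492e82fee76856c8bc3b212713"
assert id_from_key(key) == "b07060492e82fee76856c8bc3b212713"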
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'b07060492e82fee76856c8bc3b212713'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
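This KeyError is the expected cache-miss path rather than the defect itself: as the frames at _scheduler_plugin.py:229 and :234 show, get_or_create first tries _get and falls back to _create when no shuffle is active yet. A toy sketch of that control flow (hypothetical stand-in class; the real method additionally wraps the spec in ToPickle and converts P2PConsistencyError into an error message):

from __future__ import annotations

class SchedulerPluginSketch:
    # Hypothetical stand-in illustrating only the fallback, not the real plugin.
    def __init__(self) -> None:
        self.active_shuffles: dict[str, str] = {}

    def _get(self, shuffle_id: str, worker: str) -> str:
        return self.active_shuffles[shuffle_id]  # KeyError on first contact

    def _create(self, shuffle_id: str, key: str, worker: str) -> str:
        run_spec = f"run-spec-for-{shuffle_id}"
        self.active_shuffles[shuffle_id] = run_spec
        return run_spec

    def get_or_create(self, shuffle_id: str, key: str, worker: str) -> dict[str, str]:
        try:
            run_spec = self._get(shuffle_id, worker)
        except KeyError:
            # the frame at _scheduler_plugin.py:234 in this traceback
            run_spec = self._create(shuffle_id, key, worker)
        return {"status": "OK", "run_spec": run_spec}

plugin = SchedulerPluginSketch()
assert plugin.get_or_create("b07060492e82fee76856c8bc3b212713", "key", "worker")["status"] == "OK"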
During handling of the above exception, another exception occurred:
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
> yield
distributed/shuffle/_core.py:523:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/shuffle/_rechunk.py:170: in rechunk_transfer
return get_worker_plugin().add_partition(
distributed/shuffle/_worker_plugin.py:348: in add_partition
shuffle_run = self.get_or_create_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
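This AssertionError is the actual defect surfaced by the test: _retrieve_spec expects the barrier task's run_spec stored in the scheduler's task graph to be a P2PBarrierTask (which carries the ShuffleSpec in its spec attribute, per the return statement above), but finds some other task spec. A toy illustration of the failing check, with hypothetical stand-ins for dask._task_spec.Task and P2PBarrierTask:

from __future__ import annotations

from dataclasses import dataclass

class Task:
    # Stand-in for dask._task_spec.Task.
    pass

@dataclass
class P2PBarrierTask(Task):
    # Stand-in; the real class carries the ShuffleSpec for the run.
    spec: str

def retrieve_spec(run_spec: Task) -> str:
    assert isinstance(run_spec, P2PBarrierTask)
    return run_spec.spec

assert retrieve_spec(P2PBarrierTask(spec="shuffle-spec")) == "shuffle-spec"
try:
    retrieve_spec(Task())  # a plain task spec, as apparently stored here
except AssertionError:
    pass  # the failure reported above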
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:44097', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:38371', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:38175', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = 3, b = 2
orig = array([[0.6017826 , 0.45461851, 0.95201606],
[0.20078242, 0.9710892 , 0.78108917],
[0.66423832, 0.27980603, 0.07144657]])
@gen_cluster(client=True)
async def test_rechunk_expand2(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_expand2
"""
(a, b) = (3, 2)
orig = np.random.default_rng().uniform(0, 1, a**b).reshape((a,) * b)
for off, off2 in product(range(1, a - 1), range(1, a - 1)):
old = ((a - off, off),) * b
x = da.from_array(orig, chunks=old)
new = ((a - off2, off2),) * b
assert np.all(await c.compute(x.rechunk(chunks=new, method="p2p")) == orig)
if a - off - off2 > 0:
new = ((off, a - off2 - off, off2),) * b
> y = await c.compute(x.rechunk(chunks=new, method="p2p"))
distributed/shuffle/tests/test_rechunk.py:396:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P b07060492e82fee76856c8bc3b212713 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
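For local triage, the failing path can be driven outside the test harness. Below is a minimal sketch modeled on the test_rechunk_expand2 body above, assuming a local dask/distributed installation; the two-worker layout mirrors the @gen_cluster fixture, while thread counts are illustrative and not taken from the CI config:
import numpy as np
import dask.array as da
from distributed import Client, LocalCluster

if __name__ == "__main__":
    # Two workers, as in the gen_cluster fixture; sizing is illustrative.
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster, Client(cluster):
        a, b = 3, 2
        orig = np.random.default_rng().uniform(0, 1, a**b).reshape((a,) * b)
        x = da.from_array(orig, chunks=((2, 1),) * b)
        # The P2P rechunk is the step that raised
        # "P2P ... failed during transfer phase" above.
        y = x.rechunk(chunks=((1, 1, 1),) * b, method="p2p")
        print(y.compute())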
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_unknown_from_pandas (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d30e37775e332baa036f77fc7f4c89c6 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'd30e37775e332baa036f77fc7f4c89c6'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
…n.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35243', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:36087', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42641', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
pd = <module 'pandas' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/pandas/__init__.py'>
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
arr = array([[-2.65340026e-02, -6.62908051e-01, -3.74405314e-01,
1.62517681e+00, 4.22948452e-01, 1.14427358e+00,
... 1.24679262e+00, 5.95299989e-01,
3.79757690e-01, -4.28536040e-01, -8.06355440e-01,
1.46740310e+00]])
@gen_cluster(client=True)
async def test_rechunk_unknown_from_pandas(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_unknown_from_pandas
"""
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
arr = np.random.default_rng().standard_normal((50, 10))
x = dd.from_pandas(pd.DataFrame(arr), 2).values
result = x.rechunk((None, (5, 5)), method="p2p")
assert np.isnan(x.chunks[0]).all()
assert np.isnan(result.chunks[0]).all()
assert result.chunks[1] == (5, 5)
expected = da.from_array(arr, chunks=((25, 25), (10,))).rechunk(
(None, (5, 5)), method="p2p"
)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:706:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
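# Round-trip property of barrier_key()/id_from_key() above: the barrier task
# name is the fixed prefix plus the shuffle id, and id_from_key() recovers the
# id, returning None for any key that is not a P2P barrier key.
assert id_from_key(barrier_key(ShuffleId("abc"))) == ShuffleId("abc")
assert id_from_key("some-unrelated-task-1") is None
assert id_from_key(("tuple", "key", 0)) is None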
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
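# run_id draws from a process-global itertools.count, so every ShuffleRunSpec
# created in this process gets a strictly increasing id, even across different
# shuffles; this ordering is what lets restrict_task() and _get() distinguish
# stale requests from requests for runs that do not exist yet. Sketch only
# (some_spec is a hypothetical concrete spec, since ShuffleSpec is abstract,
# and strict ordering assumes no concurrent run creation):
#
#   first = ShuffleRunSpec(spec=some_spec, worker_for={}, span_id=None)
#   second = ShuffleRunSpec(spec=some_spec, worker_for={}, span_id=None)
#   assert second.run_id > first.run_id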
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P d30e37775e332baa036f77fc7f4c89c6 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
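The "RuntimeError: P2P <id> failed during transfer phase" seen throughout this report comes from the handle_transfer_errors context manager in the listing above: consistency and out-of-disk errors propagate unchanged, a closed shuffle turns into a reschedule, and anything else is wrapped so the shuffle id appears in the failure. A minimal sketch of that translation policy, with stand-in exception types rather than the real distributed classes:

import contextlib
from collections.abc import Iterator

class ShuffleClosedSketch(Exception): ...   # stand-in for ShuffleClosedError
class ConsistencySketch(Exception): ...     # stand-in for P2PConsistencyError
class RescheduleSketch(Exception): ...      # stand-in for Reschedule

@contextlib.contextmanager
def translate_transfer_errors(shuffle_id: str) -> Iterator[None]:
    try:
        yield
    except ConsistencySketch:
        raise  # already carries enough context; propagate unchanged
    except ShuffleClosedSketch:
        raise RescheduleSketch()  # run was superseded; let the task retry
    except Exception as e:
        raise RuntimeError(f"P2P {shuffle_id} failed during transfer phase") from e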
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x0-chunks0] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
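# The loop above retries only workers whose broadcast failed with OSError and
# gives up after three consecutive rounds without progress. The same pattern,
# reduced to its essentials (broadcast is a stand-in returning
# {worker: None | Exception}; non-OSError results are ignored here rather than
# re-raised as in the real code):
import asyncio
from collections.abc import Awaitable, Callable

async def retry_broadcast_sketch(
    broadcast: Callable[[list[str]], Awaitable[dict[str, Exception | None]]],
    workers: list[str],
) -> None:
    no_progress = 0
    while workers:
        res = await broadcast(workers)
        before = len(workers)
        workers = [w for w, r in res.items() if isinstance(r, OSError)]
        if workers:
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError("broadcast not making progress")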
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
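# The run_id comparison above separates two distinct failure modes: a request
# built against an older run of the same shuffle (stale) and a request that
# claims a run the scheduler has not created yet (invalid). The same decision
# as a plain function:
def classify_run_id(current: int, requested: int) -> str:
    if current > requested:
        return "stale"    # shuffle restarted since the request was built
    if current < requested:
        return "invalid"  # claims a newer run than the scheduler knows
    return "ok"

assert classify_run_id(3, 2) == "stale"
assert classify_run_id(3, 4) == "invalid"
assert classify_run_id(3, 3) == "ok"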
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '997a532b09b4edeae24e25db52e8ece8'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
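# The callback above rewrites metric labels before forwarding them to
# Worker.digests: "p2p-" prefixes are stripped and everything is namespaced
# under ("p2p", span_id, where, *label, unit). The same transform as a pure
# function:
def metric_name_sketch(span_id, where, label, unit):
    if not isinstance(label, tuple):
        label = (label,)
    if isinstance(label[0], str) and label[0].startswith("p2p-"):
        label = (label[0][len("p2p-") :], *label[1:])
    return ("p2p", span_id, where, *label, unit)

assert metric_name_sketch("span", "foreground", "p2p-shard-partition-cpu", "seconds") == (
    "p2p", "span", "foreground", "shard-partition-cpu", "seconds"
)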
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
…22: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39695', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:35371', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41665', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x1-chunks1] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '997a532b09b4edeae24e25db52e8ece8'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
…y:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40587', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:38455', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42433', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
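Both failure chains above end in the same place: the scheduler-side KeyError/AssertionError is translated by handle_transfer_errors (quoted in full above) into the generic transfer-phase RuntimeError that the test reports. A minimal sketch of that wrapping behaviour, using only the helper quoted above from distributed/shuffle/_core.py and the shuffle id from this log:
# Illustrative sketch only -- not part of the traceback. It reproduces the
# exception translation shown above: any unexpected error raised inside the
# transfer phase is re-raised as a RuntimeError tagged with the shuffle id,
# with the original exception chained via "from e".
from distributed.shuffle._core import handle_transfer_errors

try:
    with handle_transfer_errors("997a532b09b4edeae24e25db52e8ece8"):
        # Stand-in for the real failure: the scheduler plugin raised
        # KeyError('997a532b09b4edeae24e25db52e8ece8') in _get().
        raise KeyError("997a532b09b4edeae24e25db52e8ece8")
except RuntimeError as e:
    assert "failed during transfer phase" in str(e)
    assert isinstance(e.__cause__, KeyError)  # original error preserved on __cause__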
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x2-chunks2] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '997a532b09b4edeae24e25db52e8ece8'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
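The KeyError is raised on the bare shuffle id, which the scheduler recovers from the barrier task's well-known key. A small sketch of that mapping, using only the barrier_key/id_from_key helpers quoted earlier from distributed/shuffle/_core.py; as the FIXME in _create notes, a mangled barrier task name breaks this lookup chain:
# Sketch using the helpers quoted above: the barrier task key embeds the
# shuffle id behind the fixed "shuffle-barrier-" prefix, and id_from_key
# inverts barrier_key; any other key maps to None.
from distributed.shuffle._core import ShuffleId, barrier_key, id_from_key

sid = ShuffleId("997a532b09b4edeae24e25db52e8ece8")
key = barrier_key(sid)
assert key == "shuffle-barrier-997a532b09b4edeae24e25db52e8ece8"
assert id_from_key(key) == sid                 # round-trips for barrier keys
assert id_from_key("some-other-task") is None  # non-barrier keys are ignored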
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
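# --- Illustrative aside; a sketch, not part of the module listing above ---
# The callback in _capture_metrics normalizes every metric name the same
# way: bare labels become 1-tuples, a "p2p-" prefix is stripped, and the
# result is namespaced under ("p2p", <span id>, <where>, ..., <unit>).
# With a hypothetical span id and a foreground context:
def _rewrite_label_sketch(label, unit, span_id="span-1", where="foreground"):
    if not isinstance(label, tuple):
        label = (label,)
    if isinstance(label[0], str) and label[0].startswith("p2p-"):
        label = (label[0][len("p2p-") :], *label[1:])
    return ("p2p", span_id, where, *label, unit)

assert _rewrite_label_sketch("p2p-disk-write", "seconds") == (
    "p2p", "span-1", "foreground", "disk-write", "seconds"
)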
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
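# --- Illustrative aside; a toy stand-in for the branch in send() above ---
# send() batches many small shards into a single pickled blob so the comm
# layer makes one framed transfer instead of one per shard; the 65536-byte
# mean-size threshold comes from the benchmarks linked above.
# _toy_mean_shard_size is a simplification of the real _mean_shard_size.
import pickle

def _toy_mean_shard_size(shards):
    sizes = [len(buf) for _, buf in shards]
    return sum(sizes) / len(sizes)

tiny_shards = [(i, b"x" * 100) for i in range(10)]
if _toy_mean_shard_size(tiny_shards) < 65536:
    payload = pickle.dumps(tiny_shards)  # one opaque bytes blob
else:
    payload = tiny_shards  # large shards travel as individual buffers
assert isinstance(payload, bytes)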
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
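# --- Illustrative aside ---
# Both _write_to_disk and _read_from_disk flatten NDIndex tuples such as
# (1, 0) into the string "1_0", so the shard buffers only ever see flat
# string identifiers:
assert "_".join(str(i) for i in (1, 0)) == "1_0"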
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
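# --- Illustrative control flow, with hypothetical stand-ins ---
# When an output partition lands on the wrong worker, the run first asks
# the scheduler to restrict the task to its assigned worker and only then
# raises Reschedule, so the rescheduled task is guaranteed to run there:
def _ensure_local_sketch(assigned_worker, local_address, restrict):
    if assigned_worker != local_address:
        restrict(assigned_worker)  # stands in for shuffle_restrict_task
        raise Reschedule()

restrictions = []
try:
    _ensure_local_sketch("tcp://a:1234", "tcp://b:1234", restrictions.append)
except Reschedule:
    pass
assert restrictions == ["tcp://a:1234"]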
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…
…n _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
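# --- Illustrative sketch of the retry policy in barrier() above, with a
# hypothetical broadcast stub returning {worker: None | OSError} ---
# Workers whose broadcast attempt returned an OSError are retried; three
# consecutive rounds without progress abort the barrier:
import asyncio

async def _rebroadcast_sketch(broadcast, workers):
    no_progress = 0
    while workers:
        res = await broadcast(workers)
        before = len(workers)
        workers = [w for w, r in res.items() if isinstance(r, OSError)]
        if workers:
            await asyncio.sleep(0.1)
        if len(workers) == before:
            no_progress += 1
            if no_progress >= 3:
                raise RuntimeError("Broadcast not making progress")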
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
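# --- Illustrative aside ---
# restrict_task() compares run ids in both directions: a request carrying
# an older run_id is stale, while a newer one means the scheduler's own
# state is lagging; both are reported as consistency errors.
def _classify_run_id_sketch(request_run_id, active_run_id):
    if active_run_id > request_run_id:
        return "stale request"
    if active_run_id < request_run_id:
        return "invalid request"
    return "ok"

assert _classify_run_id_sketch(1, 2) == "stale request"
assert _classify_run_id_sketch(3, 2) == "invalid request"
assert _classify_run_id_sketch(2, 2) == "ok"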
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40857', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:44041', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:45303', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
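# --- Illustrative round trip ---
# barrier_key() and id_from_key() are inverses for barrier task names, and
# id_from_key() returns None for every non-barrier key (here reusing the
# shuffle id from the traceback above):
_sid = ShuffleId("997a532b09b4edeae24e25db52e8ece8")
assert id_from_key(barrier_key(_sid)) == _sid
assert id_from_key("some-other-task-key") is None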
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
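# --- Illustrative aside; a self-contained toy, not the class above ---
# run_id is excluded from __init__ and drawn from a process-global
# itertools.count(1), so every ShuffleRunSpec created in a scheduler
# process gets a strictly increasing run id:
import itertools
from dataclasses import dataclass, field
from functools import partial

_toy_counter = itertools.count(1)

@dataclass(frozen=True)
class _RunSpecSketch:
    run_id: int = field(init=False, default_factory=partial(next, _toy_counter))

_first, _second = _RunSpecSketch(), _RunSpecSketch()
assert (_first.run_id, _second.run_id) == (1, 2)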
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 997a532b09b4edeae24e25db52e8ece8 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
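Reading the chain above bottom-up: the scheduler first misses the shuffle id in active_shuffles (the KeyError), falls back to creating a new run, and _retrieve_spec then fails its P2PBarrierTask assertion; on the worker, handle_transfer_errors wraps that unexpected failure in the generic transfer-phase RuntimeError. A minimal sketch of that wrapping behaviour, using a bare AssertionError as a stand-in for the failing check:

import contextlib

@contextlib.contextmanager
def _wrap_transfer_errors_sketch(shuffle_id):
    try:
        yield
    except Exception as e:
        raise RuntimeError(f"P2P {shuffle_id} failed during transfer phase") from e

try:
    with _wrap_transfer_errors_sketch("997a532b09b4edeae24e25db52e8ece8"):
        raise AssertionError  # stands in for the P2PBarrierTask check
except RuntimeError as err:
    assert isinstance(err.__cause__, AssertionError)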
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x3-chunks3] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'cd806ec16ba2329f2cd86e74fd632cc1'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…
…22: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
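# Illustrative sketch (standalone model of the retry policy above; asyncio-based,
# not part of this module): only workers whose broadcast attempt failed with an
# OSError are retried, with a 100 ms pause between rounds, and the loop aborts
# after three rounds in which the retry set did not shrink:
#
#     import asyncio
#
#     async def broadcast_with_retry(broadcast, workers, max_stalled_rounds=3):
#         stalled = 0
#         while workers:
#             res = await broadcast(workers)   # mapping: worker -> None | OSError
#             before = len(workers)
#             workers = [w for w, r in res.items() if isinstance(r, OSError)]
#             if not workers:
#                 return
#             await asyncio.sleep(0.1)
#             if len(workers) == before:
#                 stalled += 1
#                 if stalled >= max_stalled_rounds:
#                     raise RuntimeError("Broadcast not making progress")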
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39571', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:34065', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:43371', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
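# Illustrative sketch (hypothetical span id and values, not part of this module):
# the callback above strips a "p2p-" prefix from the label and assembles the
# composite digest key, e.g.
#
#     label, value, unit = "p2p-shard-partition-cpu", 0.5, "seconds"
#     label = (label,)
#     label = (label[0][len("p2p-"):], *label[1:])       # ("shard-partition-cpu",)
#     name = ("p2p", "span-123", "foreground", *label, unit)
#     # name == ("p2p", "span-123", "foreground", "shard-partition-cpu", "seconds")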
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
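# Illustrative sketch (toy shard sizes, not part of this module): when the mean
# shard is under 64 KiB (65536 bytes), send() pickles the whole list into one
# opaque payload instead of serializing each buffer separately; receive()
# reverses this with pickle.loads:
#
#     import pickle
#     shards = [(0, b"x" * 100), (1, b"y" * 200)]   # mean size well below 65536
#     blob = pickle.dumps(shards)                   # one message over the comm
#     assert pickle.loads(blob) == shards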
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
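# Illustrative sketch (not part of this module): _write_to_disk and
# _read_from_disk agree on flattening an NDIndex into an underscore-joined
# string key:
#
#     key = (1, 2)
#     encoded = "_".join(str(i) for i in key)
#     assert encoded == "1_2"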
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
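# Illustrative sketch (simplified call, not part of this module): this method
# runs on a worker task thread, so coroutine methods are dispatched onto the
# shuffle's IOLoop with distributed.utils.sync, which blocks the calling thread
# until the coroutine completes:
#
#     from distributed.utils import sync
#     sync(loop, shuffle_run.flush_receive)   # loop and shuffle_run assumed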
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
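# Illustrative sketch (id taken from the failure above, not part of this module):
# barrier_key and id_from_key round-trip a ShuffleId through the barrier task name:
#
#     sid = ShuffleId("cd806ec16ba2329f2cd86e74fd632cc1")
#     key = barrier_key(sid)   # "shuffle-barrier-cd806ec16ba2329f2cd86e74fd632cc1"
#     assert id_from_key(key) == sid
#     assert id_from_key("some-unrelated-key") is None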
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
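# Illustrative sketch (not part of this module): run_id defaults to the next
# value of a counter created once at class-definition time, so each new
# ShuffleRunSpec gets a strictly increasing id:
#
#     from functools import partial
#     import itertools
#     next_run_id = partial(next, itertools.count(1))
#     assert (next_run_id(), next_run_id()) == (1, 2)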
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
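# Illustrative sketch (toy mapping, not part of this module): the initial set of
# participating workers is just the distinct addresses appearing in worker_for:
#
#     worker_for = {0: "tcp://a:1", 1: "tcp://b:2", 2: "tcp://a:1"}
#     assert set(worker_for.values()) == {"tcp://a:1", "tcp://b:2"}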
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
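The RuntimeError above is produced by the handle_transfer_errors context manager shown at the end of this traceback: Reschedule, P2PConsistencyError, and P2POutOfDiskError propagate unchanged, while any other exception is chained into a RuntimeError naming the shuffle id. A minimal usage sketch (hypothetical shuffle_id, run, and partition variables, assuming the context manager as shown):

with handle_transfer_errors(shuffle_id):
    run.add_partition(partition_data, partition_id)
# any unexpected failure here re-raises as
# RuntimeError("P2P <shuffle_id> failed during transfer phase")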
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x4-chunks4] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'cd806ec16ba2329f2cd86e74fd632cc1'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:42723', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:40963', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41041', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
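The RuntimeError above comes from the handle_transfer_errors() context manager at the end of the _core.py listing: ShuffleClosedError becomes Reschedule (the run was closed underneath the task, e.g. by a restart), P2PConsistencyError and P2POutOfDiskError propagate unchanged, and anything else is wrapped in a RuntimeError chained to the original cause so the underlying traceback is preserved. A minimal, self-contained sketch of the same pattern, using stand-in exception classes rather than the distributed internals:

import contextlib
from collections.abc import Iterator

class ShuffleClosedError(RuntimeError):
    """Stand-in for distributed.shuffle._exceptions.ShuffleClosedError."""

class Reschedule(Exception):
    """Stand-in for distributed.exceptions.Reschedule."""

@contextlib.contextmanager
def handle_transfer_errors(id: str) -> Iterator[None]:
    try:
        yield
    except ShuffleClosedError:
        # The run was closed underneath the task (e.g. by a restart): retry elsewhere.
        raise Reschedule()
    except Exception as e:
        # Chain the cause so the original traceback survives, as in the log above.
        raise RuntimeError(f"P2P {id} failed during transfer phase") from e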
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x5-chunks5] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[… duplicate listing of distributed/shuffle/_scheduler_plugin.py elided; identical to the listing above …]
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'cd806ec16ba2329f2cd86e74fd632cc1'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
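The KeyError above is the expected branch here: _get() is a plain dict lookup in active_shuffles, and get_or_create() one frame up catches it and falls through to _create() (visible as _scheduler_plugin.py:234 in the frames below). A self-contained sketch of that get-or-create shape, with a plain dict standing in for the plugin state:

active: dict[str, object] = {}

def _get(shuffle_id: str) -> object:
    return active[shuffle_id]  # raises KeyError for an unknown shuffle

def _create(shuffle_id: str) -> object:
    spec = object()  # stand-in for building a ShuffleRunSpec
    active[shuffle_id] = spec
    return spec

def get_or_create(shuffle_id: str) -> object:
    try:
        return _get(shuffle_id)
    except KeyError:
        # First participant to ask for this shuffle: create the run.
        return _create(shuffle_id)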
During handling of the above exception, another exception occurred:
[… duplicate listing of distributed/shuffle/_core.py elided; identical to the listing above; the annotation is truncated at this point in the original output …]
…n _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
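These frames show why a scheduler-side failure surfaces inside a worker task: the worker plugin refreshes its run spec over the shuffle_get_or_create RPC, and send_recv() (distributed/core.py:1043 above) re-raises any remote exception locally with the remote traceback attached. A small, self-contained demonstration of that re-raise mechanism, with stand-in functions in place of the comm machinery:

import sys
import traceback

def remote_handler() -> None:
    # Stand-in for the scheduler-side handler that fails.
    raise AssertionError("barrier task spec has the wrong type")

def send_recv() -> None:
    # Stand-in: pretend (exc, tb) arrived over the comm, then re-raise
    # locally so the caller's traceback includes the "remote" frames.
    try:
        remote_handler()
    except AssertionError:
        _, exc, tb = sys.exc_info()
        raise exc.with_traceback(tb)

try:
    send_recv()
except AssertionError:
    traceback.print_exc()  # printed traceback includes remote_handler's frame

The scheduler-side frame follows: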
[… duplicate listing of distributed/shuffle/_scheduler_plugin.py elided; identical to the listing above …]
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
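_retrieve_spec() locates the barrier task purely by name (see the FIXME in _create above) and asserts that its run_spec is a P2PBarrierTask; when the graph no longer carries that spec, the assertion fails and travels back to the worker as shown. The naming round-trip it relies on is plain prefixing; restating the helpers from the _core.py listing with a quick check:

_BARRIER_PREFIX = "shuffle-barrier-"

def barrier_key(shuffle_id: str) -> str:
    return _BARRIER_PREFIX + shuffle_id

def id_from_key(key: str) -> str | None:
    if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
        return None
    return key[len(_BARRIER_PREFIX):]

assert barrier_key("cd806ec16ba2329f2cd86e74fd632cc1") == (
    "shuffle-barrier-cd806ec16ba2329f2cd86e74fd632cc1"
)
assert id_from_key("shuffle-barrier-abc") == "abc"
assert id_from_key("rechunk-transfer-abc") is None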
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33355', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:38757', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41363', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
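For orientation: the parametrized test above erases the chunk sizes of axis 0 by round-tripping the array through dask.dataframe (dd.from_array(x).values has unknown/NaN chunks along that axis), then rechunks with method="p2p". A minimal reproduction of this failing case, assuming a distributed Client is connected (p2p requires one):

import dask.array as da
import dask.dataframe as dd

x = da.ones(shape=(100, 10), chunks=(5, 10))
y = dd.from_array(x).values              # axis 0 now has unknown (nan) chunk sizes
result = y.rechunk((None, (5, 5)), method="p2p")
result.compute()                         # raised the transfer-phase RuntimeError above

The traceback then descends into _core.py's handle_transfer_errors: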
[… duplicate listing of distributed/shuffle/_core.py elided; identical to the listing above …]
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P cd806ec16ba2329f2cd86e74fd632cc1 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
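Editor's sketch (not part of the test output): the frame above is the usual way
transfer tasks wrap their work in handle_transfer_errors; add_partition is the
real ShuffleRun method, toy_transfer is a made-up stand-in for rechunk_transfer.

    def toy_transfer(run: ShuffleRun, data, partition_id) -> int:
        # ShuffleClosedError -> Reschedule; P2P consistency/out-of-disk errors
        # propagate unchanged; anything else becomes the RuntimeError above
        with handle_transfer_errors(run.id):
            return run.add_partition(data, partition_id)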
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x6-chunks6] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'd1cb5122ca49df832e6e59ceb0060c25'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
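# --- editor's sketch, not part of the dumped source --------------------------
# A standalone replica of the callback above, showing the resulting metric key
# under an assumed span id of "span-1": the "p2p-" prefix is stripped from the
# label and the pieces are reassembled into the digest tuple.
def _toy_name(label, unit, span_id="span-1", where="foreground"):
    if not isinstance(label, tuple):
        label = (label,)
    if isinstance(label[0], str) and label[0].startswith("p2p-"):
        label = (label[0][len("p2p-") :], *label[1:])
    return ("p2p", span_id, where, *label, unit)
assert _toy_name("p2p-shard-partition-cpu", "seconds") == (
    "p2p", "span-1", "foreground", "shard-partition-cpu", "seconds"
)
# ------------------------------------------------------------------------------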
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
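# --- editor's sketch, not part of the dumped source --------------------------
# Effect of the 65536-byte mean-shard-size threshold in send(): many tiny
# shards travel as one opaque pickled blob and receive() rebuilds the list via
# pickle.loads, while large shards are sent as-is.
tiny_shards = [(i, b"x" * 100) for i in range(1000)]  # mean size << 65536
blob = pickle.dumps(tiny_shards)                      # what goes over the comm
assert pickle.loads(blob) == tiny_shards              # what receive() unpacks
# ------------------------------------------------------------------------------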
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
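# --- editor's sketch, not part of the dumped source --------------------------
# How a stored failure resurfaces: fail() records the exception while the run
# is still open; after close(), raise_if_closed() re-raises it instead of the
# generic ShuffleClosedError.
#   run.fail(RuntimeError("worker left"))  # closed is False -> exception kept
#   await run.close()                      # closed = True
#   run.raise_if_closed()                  # raises RuntimeError("worker left")
# ------------------------------------------------------------------------------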
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…
…22: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45147', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:33203', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:43661', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x7-chunks7] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'd1cb5122ca49df832e6e59ceb0060c25'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
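Note: this KeyError is the expected "miss" branch of the shuffle_get_or_create handler. As the get_or_create frame above shows, the plugin first tries _get and only falls back to _create when no active run exists for the id. A minimal sketch of that pattern (illustrative names only; the real method also validates the requesting worker and wraps failures in an ErrorMessage):

class _Plugin:
    """Stripped-down stand-in for ShuffleSchedulerPlugin's lookup logic."""

    def __init__(self):
        self.active_shuffles = {}  # ShuffleId -> shuffle state

    def _get(self, shuffle_id):
        return self.active_shuffles[shuffle_id]  # KeyError on a miss

    def _create(self, shuffle_id):
        state = object()  # the real code builds a SchedulerShuffleState
        self.active_shuffles[shuffle_id] = state
        return state

    def get_or_create(self, shuffle_id):
        try:
            return self._get(shuffle_id)
        except KeyError:
            return self._create(shuffle_id)

p = _Plugin()
sid = "d1cb5122ca49df832e6e59ceb0060c25"
assert p.get_or_create(sid) is p.get_or_create(sid)  # second call hits _get

The chained trace below shows that fallback (_create) itself failing while retrieving the shuffle spec.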
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…
…y:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[… distributed/shuffle/_scheduler_plugin.py module source repeated verbatim; see the first listing above …]
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
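Note: _retrieve_spec locates the shuffle's barrier task purely by naming convention and then asserts that its run_spec is a P2PBarrierTask; the AssertionError means the task was found but carried a run_spec of a different type (see also the FIXME in _create further down: once the barrier name is mangled, the guarantee is void). The convention, mirroring barrier_key/id_from_key from distributed/shuffle/_core.py as listed below:

_BARRIER_PREFIX = "shuffle-barrier-"

def barrier_key(shuffle_id):
    return _BARRIER_PREFIX + shuffle_id

def id_from_key(key):
    if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
        return None
    return key[len(_BARRIER_PREFIX):]

sid = "d1cb5122ca49df832e6e59ceb0060c25"
assert id_from_key(barrier_key(sid)) == sid  # round-trips by construction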
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38539', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:46847', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:34357', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
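For reference, this parametrization distills to roughly the following standalone script — a sketch assuming a local two-worker cluster and the fixture values shown above (x = da.ones((10, 10), chunks=(10, 10)) rechunked to {1: 5}), and a dask version whose array rechunk accepts method="p2p":

import dask.array as da
import dask.dataframe as dd
from distributed import Client

if __name__ == "__main__":
    client = Client(n_workers=2)
    x = da.ones(shape=(10, 10), chunks=(10, 10))
    y = dd.from_array(x).values  # axis-0 chunk sizes become unknown (NaN)
    result = y.rechunk({1: 5}, method="p2p")
    expected = x.rechunk({1: 5}, method="p2p")
    print(result.compute())   # the CI failure occurs while computing these
    print(expected.compute())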
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[… distributed/shuffle/_core.py (ShuffleRun) module source repeated verbatim up to add_partition; see the first listing above …]
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
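Note: the final RuntimeError is manufactured by handle_transfer_errors (listed just above): any exception raised during the transfer phase that is not a P2P-specific one is re-raised as a RuntimeError chained from the original, which is why the scheduler-side AssertionError surfaces under this generic message. A simplified sketch, omitting the ShuffleClosedError reschedule and the P2PConsistencyError/P2POutOfDiskError pass-through branches:

import contextlib

@contextlib.contextmanager
def handle_transfer_errors(id):
    try:
        yield
    except Exception as e:
        raise RuntimeError(f"P2P {id} failed during transfer phase") from e

try:
    with handle_transfer_errors("d1cb5122ca49df832e6e59ceb0060c25"):
        raise AssertionError  # stands in for the scheduler-side failure
except RuntimeError as e:
    assert isinstance(e.__cause__, AssertionError)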
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x8-chunks8] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
[… distributed/shuffle/_scheduler_plugin.py module source repeated verbatim through _get; see the first listing above …]
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[… distributed/shuffle/_scheduler_plugin.py module source repeated verbatim; see the first listing above …]
> state = self.active_shuffles[id]
E KeyError: 'd1cb5122ca49df832e6e59ceb0060c25'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
… [repeated source listing of distributed/shuffle/_core.py (ShuffleRun) elided; the full listing appears in the traceback below] …
# Log metrics both in the "execute" and…n _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
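# NOTE (illustrative sketch; not part of the quoted module):
# `heartbeats` is a two-level defaultdict keyed first by shuffle id and
# then by worker address, so per-worker dashboard data can be merged in
# without pre-creating either level:
#
#     from collections import defaultdict
#     heartbeats = defaultdict(lambda: defaultdict(dict))
#     heartbeats["id-1"]["tcp://w1:1234"].update({"read": 42})
#     assert heartbeats["id-1"]["tcp://w1:1234"] == {"read": 42}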
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
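# NOTE (illustrative sketch; not part of the quoted module):
# Each broadcast round keeps only the workers that failed with an OSError
# for the next retry; any other error aborts the barrier, and three rounds
# without the retry set shrinking raise. A toy version of the per-round
# filtering, assuming `res` maps worker address -> None (ok) or an
# exception instance:
#
#     res = {"tcp://w1": None, "tcp://w2": OSError("timed out")}
#     retry = [w for w, r in res.items() if isinstance(r, OSError)]
#     assert retry == ["tcp://w2"]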
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:37997', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:35021', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42729', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
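# NOTE (illustrative sketch; not part of the quoted module):
# The retry settings come from dask config; `parse_timedelta` normalizes
# strings such as "300ms" to float seconds, and `default="s"` makes bare
# numbers read as seconds:
#
#     from dask.utils import parse_timedelta
#     assert parse_timedelta("300ms", default="s") == 0.3
#     assert parse_timedelta(5, default="s") == 5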
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
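# NOTE (illustrative sketch; not part of the quoted module):
# The callback normalizes labels before digesting: scalar labels become
# 1-tuples and a leading "p2p-" prefix is stripped. For example, with
# where="foreground" and span_id="span":
#
#     label = "p2p-shard-partition-cpu"
#     if not isinstance(label, tuple):
#         label = (label,)
#     if isinstance(label[0], str) and label[0].startswith("p2p-"):
#         label = (label[0][len("p2p-"):], *label[1:])
#     assert ("p2p", "span", "foreground", *label, "seconds") == (
#         "p2p", "span", "foreground", "shard-partition-cpu", "seconds")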
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
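# NOTE (illustrative sketch; not part of the quoted module):
# Shards with a mean size under 64 KiB (65536 bytes) are pickled into one
# opaque blob so a single comm message replaces many tiny buffer
# transfers; receive() reverses this. A toy round-trip with the
# (partition_id, data) shard shape:
#
#     import pickle
#     shards = [(0, b"x" * 1024), (1, b"y" * 1024)]   # ~1 KiB each
#     blob = pickle.dumps(shards)                      # one message
#     assert pickle.loads(blob) == shards              # unpacked on receive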
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
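# NOTE (illustrative sketch; not part of the quoted module):
# Disk keys are the flattened NDIndex tuples; _read_from_disk below
# applies the same mapping:
#
#     key = "_".join(str(i) for i in (1, 2, 0))
#     assert key == "1_2_0"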
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
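# NOTE (illustrative sketch; not part of the quoted module):
# The two helpers are inverses on string keys, and id_from_key rejects
# anything that is not a barrier key:
#
#     assert barrier_key(ShuffleId("abc")) == "shuffle-barrier-abc"
#     assert id_from_key("shuffle-barrier-abc") == "abc"
#     assert id_from_key(("tuple", "key")) is None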
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
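# NOTE (illustrative sketch; not part of the quoted module):
# `run_id` is excluded from __init__ and drawn from a shared counter, so
# every ShuffleRunSpec created in this process gets a fresh, monotonically
# increasing id starting at 1:
#
#     from functools import partial
#     import itertools
#     fresh = partial(next, itertools.count(1))
#     assert (fresh(), fresh(), fresh()) == (1, 2, 3)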
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P d1cb5122ca49df832e6e59ceb0060c25 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
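The RuntimeError above comes out of handle_transfer_errors, which translates any unexpected exception raised during the transfer phase into the "P2P ... failed during transfer phase" error seen in this report, while letting the P2P control-flow exceptions pass through. A minimal self-contained sketch of that translation, using stand-in exception classes rather than the distributed ones:

    import contextlib

    class ShuffleClosedError(Exception): ...
    class Reschedule(Exception): ...

    @contextlib.contextmanager
    def handle_transfer_errors(id):
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except Exception as e:
            raise RuntimeError(f"P2P {id} failed during transfer phase") from e

    try:
        with handle_transfer_errors("d1cb5122"):
            raise ValueError("boom")
    except RuntimeError as e:
        assert isinstance(e.__cause__, ValueError)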
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
5 out of 6 runs failed: test_rechunk_with_fully_unknown_dimension[x9-chunks9] (distributed.shuffle.tests.test_rechunk)
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P a5d6a357fabfab24dace4a1ae3b6eb0d failed during transfer phase
… [repeated source listing of distributed/shuffle/_scheduler_plugin.py elided; identical to the listing above, continuing at ShuffleSchedulerPlugin._retrieve_spec] …
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
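# NOTE (descriptive comment, not part of the quoted module): this is the
# assertion reported failing above (AssertionError at
# distributed/shuffle/_scheduler_plugin.py:196). The scheduler looks up
# the barrier task by its well-known key and expects its run_spec to be a
# P2PBarrierTask carrying the ShuffleSpec; if the barrier task's spec has
# been replaced, the isinstance check trips. The FIXME in _create below
# describes the same reliance on the unmangled barrier name.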
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
… [repeated source listing of distributed/shuffle/_scheduler_plugin.py elided; identical to the listing above, continuing at ShuffleSchedulerPlugin._get] …
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'a5d6a357fabfab24dace4a1ae3b6eb0d'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
… [repeated source listing of distributed/shuffle/_core.py (ShuffleRun) elided; identical to the listing above] …
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
… [repeated source listing of distributed/shuffle/_scheduler_plugin.py elided; identical to the listing above, continuing at ShuffleSchedulerPlugin._retrieve_spec] …
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36439', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 2), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:37419', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:45359', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/dataframe/__init__.py'>
… [repeated test source elided; identical to the test_rechunk_with_fully_unknown_dimension listing above] …
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:153: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
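The retry and buffering knobs read above come straight from dask's config and are normalized by parse_timedelta/parse_bytes. A quick illustration of how those parsers behave (the literal values here are made up, not distributed's defaults):

    from dask.utils import parse_bytes, parse_timedelta

    parse_timedelta("300ms", default="s")  # -> 0.3 (seconds)
    parse_timedelta(30, default="s")       # -> 30 (bare numbers are taken as seconds)
    parse_bytes("64 KiB")                  # -> 65536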
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
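Concretely, the callback above strips a leading "p2p-" from tuple labels before assembling the digest key. A worked example of that rewriting (the span id and unit are made-up values):

    label = ("p2p-shard-partition-cpu",)
    if isinstance(label[0], str) and label[0].startswith("p2p-"):
        label = (label[0][len("p2p-"):], *label[1:])
    name = ("p2p", "<span-id>", "foreground", *label, "seconds")
    # name == ("p2p", "<span-id>", "foreground", "shard-partition-cpu", "seconds")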
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
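The barrier collects the run_id reported by every transfer task and flags the shuffle as consistent only if all of them match this run. How the scheduler-side shuffle_barrier handler reacts to an inconsistent report lives in the scheduler plugin and is not shown here; the check itself is just:

    run_id = 2
    all(r == run_id for r in [2, 2, 2])  # True  -> consistent
    all(r == run_id for r in [2, 1, 2])  # False -> a transfer reported a stale run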
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
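The size cut-off in send() batches many tiny shards into a single pickled blob to avoid per-buffer comm overhead. _mean_shard_size is defined further down in _core.py and is not shown here; a rough stand-in, assuming bytes-like shard payloads (pack_for_send is a hypothetical name):

    import pickle

    def pack_for_send(shards, threshold=65536):
        # Average the payload sizes and, below the threshold, ship one
        # opaque pickled blob instead of many small individually
        # serialized buffers.
        sizes = [len(payload) for _, payload in shards]
        mean = sum(sizes) / len(sizes) if sizes else 0
        return pickle.dumps(shards) if mean < threshold else shards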
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
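close() is safe to call concurrently: the first caller flips `closed` and releases the buffers, while later callers block on `_closed_event` until teardown finishes. The pattern in isolation (a minimal, hypothetical class):

    import asyncio

    class Closable:
        def __init__(self) -> None:
            self.closed = False
            self._closed_event = asyncio.Event()

        async def close(self) -> None:
            if self.closed:
                # A close is already in flight; wait for it to finish
                # instead of tearing down twice.
                await self._closed_event.wait()
                return
            self.closed = True
            # ... release buffers and other resources here ...
            self._closed_event.set()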
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
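_write_to_disk and _read_from_disk agree on a single scheme for turning an NDIndex into a file-name-safe string:

    key = (2, 0, 1)                    # an NDIndex
    "_".join(str(i) for i in key)      # -> "2_0_1", the on-disk shard name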
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
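add_partition runs in the worker's task thread, so it uses distributed.utils.sync to hop onto the event loop that owns the comm buffer. Roughly, as a sketch assuming a tornado IOLoop backed by asyncio (sync_sketch is hypothetical):

    import asyncio

    def sync_sketch(loop, async_fn, *args):
        # Submit the coroutine to the loop's thread and block this thread
        # until it completes; distributed.utils.sync layers timeouts and
        # error handling on top of this idea.
        future = asyncio.run_coroutine_threadsafe(async_fn(*args), loop.asyncio_loop)
        return future.result()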
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
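The two helpers are inverses for barrier keys, and id_from_key returns None for anything else. Using the shuffle id from the failure above:

    sid = ShuffleId("a5d6a357fabfab24dace4a1ae3b6eb0d")
    key = barrier_key(sid)   # "shuffle-barrier-a5d6a357fabfab24dace4a1ae3b6eb0d"
    assert id_from_key(key) == sid
    assert id_from_key("some-other-task") is None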
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
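Because run_id uses default_factory=partial(next, itertools.count(1)), every ShuffleRunSpec created in a scheduler process gets the next integer in a process-global sequence. The factory in isolation:

    import itertools
    from functools import partial

    next_run_id = partial(next, itertools.count(1))
    assert (next_run_id(), next_run_id(), next_run_id()) == (1, 2, 3)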
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P a5d6a357fabfab24dace4a1ae3b6eb0d failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError