Skip to content

Commit

Permalink
Monkeypatch protocol.loads ala dask/distributed#8216 (#1247)
Browse files Browse the repository at this point in the history
In versions of distributed after dask/distributed#8067 but before dask/distributed#8216, we must patch protocol.loads to include the same decompression fix.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #1247
  • Loading branch information
wence- authored Sep 27, 2023
1 parent 6bd4ba4 commit 7400f95
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 0 deletions.
1 change: 1 addition & 0 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

__version__ = "23.10.00"

from . import compat

# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
Expand Down
118 changes: 118 additions & 0 deletions dask_cuda/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pickle

import msgpack
from packaging.version import Version

import dask
import distributed
import distributed.comm.utils
import distributed.protocol
from distributed.comm.utils import OFFLOAD_THRESHOLD, nbytes, offload
from distributed.protocol.core import (
Serialized,
decompress,
logger,
merge_and_deserialize,
msgpack_decode_default,
msgpack_opts,
)

if Version(distributed.__version__) >= Version("2023.8.1"):
# Monkey-patch protocol.core.loads (and its users)
async def from_frames(
frames, deserialize=True, deserializers=None, allow_offload=True
):
"""
Unserialize a list of Distributed protocol frames.
"""
size = False

def _from_frames():
try:
# Patched code
return loads(
frames, deserialize=deserialize, deserializers=deserializers
)
# end patched code
except EOFError:
if size > 1000:
datastr = "[too large to display]"
else:
datastr = frames
# Aid diagnosing
logger.error("truncated data stream (%d bytes): %s", size, datastr)
raise

if allow_offload and deserialize and OFFLOAD_THRESHOLD:
size = sum(map(nbytes, frames))
if (
allow_offload
and deserialize
and OFFLOAD_THRESHOLD
and size > OFFLOAD_THRESHOLD
):
res = await offload(_from_frames)
else:
res = _from_frames()

return res

def loads(frames, deserialize=True, deserializers=None):
"""Transform bytestream back into Python value"""

allow_pickle = dask.config.get("distributed.scheduler.pickle")

try:

def _decode_default(obj):
offset = obj.get("__Serialized__", 0)
if offset > 0:
sub_header = msgpack.loads(
frames[offset],
object_hook=msgpack_decode_default,
use_list=False,
**msgpack_opts,
)
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
if deserialize:
if "compression" in sub_header:
sub_frames = decompress(sub_header, sub_frames)
return merge_and_deserialize(
sub_header, sub_frames, deserializers=deserializers
)
else:
return Serialized(sub_header, sub_frames)

offset = obj.get("__Pickled__", 0)
if offset > 0:
sub_header = msgpack.loads(frames[offset])
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
# Patched code
if "compression" in sub_header:
sub_frames = decompress(sub_header, sub_frames)
# end patched code
if allow_pickle:
return pickle.loads(
sub_header["pickled-obj"], buffers=sub_frames
)
else:
raise ValueError(
"Unpickle on the Scheduler isn't allowed, "
"set `distributed.scheduler.pickle=true`"
)

return msgpack_decode_default(obj)

return msgpack.loads(
frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
)

except Exception:
logger.critical("Failed to deserialize", exc_info=True)
raise

distributed.protocol.loads = loads
distributed.protocol.core.loads = loads
distributed.comm.utils.from_frames = from_frames
18 changes: 18 additions & 0 deletions dask_cuda/tests/test_from_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

import dask.array as da
from distributed import Client

from dask_cuda import LocalCUDACluster

pytest.importorskip("ucp")
cupy = pytest.importorskip("cupy")


@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
def test_ucx_from_array(protocol):
N = 10_000
with LocalCUDACluster(protocol=protocol) as cluster:
with Client(cluster):
val = da.from_array(cupy.arange(N), chunks=(N // 10,)).sum().compute()
assert val == (N * (N - 1)) // 2

0 comments on commit 7400f95

Please sign in to comment.