-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Monkeypatch protocol.loads ala dask/distributed#8216 (#1247)
In versions of distributed after dask/distributed#8067 but before dask/distributed#8216, we must patch protocol.loads to include the same decompression fix.

Authors:
- Lawrence Mitchell (https://github.com/wence-)

Approvers:
- Peter Andreas Entschev (https://github.com/pentschev)

URL: #1247
- Loading branch information
Showing
3 changed files
with
137 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import pickle | ||
|
||
import msgpack | ||
from packaging.version import Version | ||
|
||
import dask | ||
import distributed | ||
import distributed.comm.utils | ||
import distributed.protocol | ||
from distributed.comm.utils import OFFLOAD_THRESHOLD, nbytes, offload | ||
from distributed.protocol.core import ( | ||
Serialized, | ||
decompress, | ||
logger, | ||
merge_and_deserialize, | ||
msgpack_decode_default, | ||
msgpack_opts, | ||
) | ||
|
||
if Version(distributed.__version__) >= Version("2023.8.1"):
    # Swap in replacements for protocol.core.loads and for the comm helper
    # that calls it. The replacement ``loads`` also decompresses the
    # sub-frames of pickled objects before unpickling them.

    async def from_frames(
        frames, deserialize=True, deserializers=None, allow_offload=True
    ):
        """
        Rebuild a message from a list of Distributed protocol frames.

        The frames are decoded with the ``loads`` defined in this module.
        Decoding is handed to ``offload`` when offloading is allowed,
        deserialization is requested, ``OFFLOAD_THRESHOLD`` is truthy and
        the summed ``nbytes`` of the frames exceeds it; otherwise it runs
        inline. An ``EOFError`` raised while decoding is logged and
        re-raised.
        """
        # Stays False unless the offload size check below computes it; the
        # error log then formats it with %d.
        size = False

        def _decode_frames():
            try:
                # Calls the replacement loads defined further down.
                return loads(
                    frames, deserialize=deserialize, deserializers=deserializers
                )
            except EOFError:
                # Aid diagnosing: avoid dumping very large payloads.
                datastr = "[too large to display]" if size > 1000 else frames
                logger.error("truncated data stream (%d bytes): %s", size, datastr)
                raise

        offload_possible = allow_offload and deserialize and OFFLOAD_THRESHOLD
        if offload_possible:
            size = sum(nbytes(frame) for frame in frames)

        if offload_possible and size > OFFLOAD_THRESHOLD:
            return await offload(_decode_frames)
        return _decode_frames()

    def loads(frames, deserialize=True, deserializers=None):
        """
        Decode ``frames`` back into a Python value.

        ``frames[0]`` is unpacked with msgpack. Mappings carrying a positive
        ``"__Serialized__"`` or ``"__Pickled__"`` value are treated as
        references into ``frames``: that value is the index of a sub-header
        frame, and the following ``num-sub-frames`` frames are the payload.
        Any failure is logged at critical level and re-raised.
        """
        allow_pickle = dask.config.get("distributed.scheduler.pickle")

        def _resolve(obj):
            header_index = obj.get("__Serialized__", 0)
            if header_index > 0:
                header = msgpack.loads(
                    frames[header_index],
                    object_hook=msgpack_decode_default,
                    use_list=False,
                    **msgpack_opts,
                )
                start = header_index + 1
                payload = frames[start : start + header["num-sub-frames"]]
                if not deserialize:
                    # Caller wants the raw header/frames pair.
                    return Serialized(header, payload)
                if "compression" in header:
                    payload = decompress(header, payload)
                return merge_and_deserialize(
                    header, payload, deserializers=deserializers
                )

            header_index = obj.get("__Pickled__", 0)
            if header_index > 0:
                header = msgpack.loads(frames[header_index])
                start = header_index + 1
                payload = frames[start : start + header["num-sub-frames"]]
                # Patched code: pickled payloads get the same decompression
                # step as the serialized branch above.
                if "compression" in header:
                    payload = decompress(header, payload)
                # end patched code
                if not allow_pickle:
                    raise ValueError(
                        "Unpickle on the Scheduler isn't allowed, "
                        "set `distributed.scheduler.pickle=true`"
                    )
                return pickle.loads(header["pickled-obj"], buffers=payload)

            return msgpack_decode_default(obj)

        try:
            return msgpack.loads(
                frames[0], object_hook=_resolve, use_list=False, **msgpack_opts
            )
        except Exception:
            logger.critical("Failed to deserialize", exc_info=True)
            raise

    # Install the replacements on the distributed modules.
    distributed.protocol.loads = loads
    distributed.protocol.core.loads = loads
    distributed.comm.utils.from_frames = from_frames
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import pytest | ||
|
||
import dask.array as da | ||
from distributed import Client | ||
|
||
from dask_cuda import LocalCUDACluster | ||
|
||
pytest.importorskip("ucp") | ||
cupy = pytest.importorskip("cupy") | ||
|
||
|
||
@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
def test_ucx_from_array(protocol):
    """Sum a chunked CuPy-backed dask array on a LocalCUDACluster.

    Runs once per parametrized ``protocol`` and checks the computed sum of
    ``arange(n)`` against the closed form ``n * (n - 1) // 2``.
    """
    n_items = 10_000
    expected_total = (n_items * (n_items - 1)) // 2
    with LocalCUDACluster(protocol=protocol) as cluster, Client(cluster):
        device_values = cupy.arange(n_items)
        # Ten equally sized chunks.
        chunked = da.from_array(device_values, chunks=(n_items // 10,))
        val = chunked.sum().compute()
        assert val == expected_total