diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index 1f3f03f4be1..a7ea9c25a86 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -1,7 +1,5 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. -import pickle - from libcpp cimport bool import pylibcudf @@ -358,14 +356,13 @@ class PackedColumns(Serializable): header["index-names"] = self.index_names header["metadata"] = self._metadata.tobytes() for name, dtype in self.column_dtypes.items(): - dtype_header, dtype_frames = dtype.serialize() + dtype_header, dtype_frames = dtype.device_serialize() self.column_dtypes[name] = ( dtype_header, (len(frames), len(frames) + len(dtype_frames)), ) frames.extend(dtype_frames) header["column-dtypes"] = self.column_dtypes - header["type-serialized"] = pickle.dumps(type(self)) return header, frames @classmethod @@ -373,9 +370,9 @@ class PackedColumns(Serializable): column_dtypes = {} for name, dtype in header["column-dtypes"].items(): dtype_header, (start, stop) = dtype - column_dtypes[name] = pickle.loads( - dtype_header["type-serialized"] - ).deserialize(dtype_header, frames[start:stop]) + column_dtypes[name] = Serializable.device_deserialize( + dtype_header, frames[start:stop] + ) return cls( plc.contiguous_split.pack( plc.contiguous_split.unpack_from_memoryviews( diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index 2df154ee112..1b6152b81ca 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -2,7 +2,6 @@ from __future__ import annotations -import pickle import warnings from functools import cached_property from typing import TYPE_CHECKING, Any, Literal @@ -330,13 +329,6 @@ def get_level_values(self, level): else: raise KeyError(f"Requested level with name {level} " "not found") - @classmethod - def deserialize(cls, header, frames): - # Dispatch deserialization to the appropriate index type in case - # deserialization is ever attempted with the base class directly. - idx_type = pickle.loads(header["type-serialized"]) - return idx_type.deserialize(header, frames) - @property def names(self): """ diff --git a/python/cudf/cudf/core/abc.py b/python/cudf/cudf/core/abc.py index ce6bb83bc77..c8ea03b04fe 100644 --- a/python/cudf/cudf/core/abc.py +++ b/python/cudf/cudf/core/abc.py @@ -1,8 +1,6 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. """Common abstract base classes for cudf.""" -import pickle - import numpy import cudf @@ -22,6 +20,14 @@ class Serializable: latter converts back from that representation into an equivalent object. """ + # A mapping from class names to the classes themselves. This is used to + # reconstruct the correct class when deserializing an object. + _name_type_map: dict = {} + + def __init_subclass__(cls, /, **kwargs): + super().__init_subclass__(**kwargs) + cls._name_type_map[cls.__name__] = cls + def serialize(self): """Generate an equivalent serializable representation of an object. @@ -98,7 +104,7 @@ def device_serialize(self): ) for f in frames ) - header["type-serialized"] = pickle.dumps(type(self)) + header["type-serialized-name"] = type(self).__name__ header["is-cuda"] = [ hasattr(f, "__cuda_array_interface__") for f in frames ] @@ -128,10 +134,10 @@ def device_deserialize(cls, header, frames): :meta private: """ - typ = pickle.loads(header["type-serialized"]) + typ = cls._name_type_map[header["type-serialized-name"]] frames = [ cudf.core.buffer.as_buffer(f) if c else memoryview(f) - for c, f in zip(header["is-cuda"], frames) + for c, f in zip(header["is-cuda"], frames, strict=True) ] return typ.deserialize(header, frames) diff --git a/python/cudf/cudf/core/buffer/buffer.py b/python/cudf/cudf/core/buffer/buffer.py index ffa306bf93f..625938ca168 100644 --- a/python/cudf/cudf/core/buffer/buffer.py +++ b/python/cudf/cudf/core/buffer/buffer.py @@ -3,7 +3,6 @@ from __future__ import annotations import math -import pickle import weakref from types import SimpleNamespace from typing import TYPE_CHECKING, Any, Literal @@ -432,8 +431,7 @@ def serialize(self) -> tuple[dict, list]: second element is a list containing single frame. """ header: dict[str, Any] = {} - header["type-serialized"] = pickle.dumps(type(self)) - header["owner-type-serialized"] = pickle.dumps(type(self._owner)) + header["owner-type-serialized-name"] = type(self._owner).__name__ header["frame_count"] = 1 frames = [self] return header, frames @@ -460,7 +458,9 @@ def deserialize(cls, header: dict, frames: list) -> Self: if isinstance(frame, cls): return frame # The frame is already deserialized - owner_type: BufferOwner = pickle.loads(header["owner-type-serialized"]) + owner_type: BufferOwner = Serializable._name_type_map[ + header["owner-type-serialized-name"] + ] if hasattr(frame, "__cuda_array_interface__"): owner = owner_type.from_device_memory(frame, exposed=False) else: diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 7305ff651c6..cbb65229933 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -3,7 +3,6 @@ from __future__ import annotations import collections.abc -import pickle import time import weakref from threading import RLock @@ -415,8 +414,7 @@ def serialize(self) -> tuple[dict, list]: header: dict[str, Any] = {} frames: list[Buffer | memoryview] with self._owner.lock: - header["type-serialized"] = pickle.dumps(self.__class__) - header["owner-type-serialized"] = pickle.dumps(type(self._owner)) + header["owner-type-serialized-name"] = type(self._owner).__name__ header["frame_count"] = 1 if self.is_spilled: frames = [self.memoryview()] diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index b317858077f..68307f0e109 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -2,7 +2,6 @@ from __future__ import annotations -import pickle from collections import abc from collections.abc import MutableSequence, Sequence from functools import cached_property @@ -1294,28 +1293,27 @@ def serialize(self) -> tuple[dict, list]: header: dict[Any, Any] = {} frames = [] - header["type-serialized"] = pickle.dumps(type(self)) try: - dtype, dtype_frames = self.dtype.serialize() + dtype, dtype_frames = self.dtype.device_serialize() header["dtype"] = dtype frames.extend(dtype_frames) header["dtype-is-cudf-serialized"] = True except AttributeError: - header["dtype"] = pickle.dumps(self.dtype) + header["dtype"] = self.dtype.str header["dtype-is-cudf-serialized"] = False if self.data is not None: - data_header, data_frames = self.data.serialize() + data_header, data_frames = self.data.device_serialize() header["data"] = data_header frames.extend(data_frames) if self.mask is not None: - mask_header, mask_frames = self.mask.serialize() + mask_header, mask_frames = self.mask.device_serialize() header["mask"] = mask_header frames.extend(mask_frames) if self.children: child_headers, child_frames = zip( - *(c.serialize() for c in self.children) + *(c.device_serialize() for c in self.children) ) header["subheaders"] = list(child_headers) frames.extend(chain(*child_frames)) @@ -1327,8 +1325,7 @@ def serialize(self) -> tuple[dict, list]: def deserialize(cls, header: dict, frames: list) -> ColumnBase: def unpack(header, frames) -> tuple[Any, list]: count = header["frame_count"] - klass = pickle.loads(header["type-serialized"]) - obj = klass.deserialize(header, frames[:count]) + obj = cls.device_deserialize(header, frames[:count]) return obj, frames[count:] assert header["frame_count"] == len(frames), ( @@ -1338,7 +1335,7 @@ def unpack(header, frames) -> tuple[Any, list]: if header["dtype-is-cudf-serialized"]: dtype, frames = unpack(header["dtype"], frames) else: - dtype = pickle.loads(header["dtype"]) + dtype = np.dtype(header["dtype"]) if "data" in header: data, frames = unpack(header["data"], frames) else: @@ -2307,7 +2304,9 @@ def serialize_columns(columns: list[ColumnBase]) -> tuple[list[dict], list]: frames = [] if len(columns) > 0: - header_columns = [c.serialize() for c in columns] + header_columns: list[tuple[dict, list]] = [ + c.device_serialize() for c in columns + ] headers, column_frames = zip(*header_columns) for f in column_frames: frames.extend(f) @@ -2324,7 +2323,7 @@ def deserialize_columns(headers: list[dict], frames: list) -> list[ColumnBase]: for meta in headers: col_frame_count = meta["frame_count"] - col_typ = pickle.loads(meta["type-serialized"]) + col_typ = Serializable._name_type_map[meta["type-serialized-name"]] colobj = col_typ.deserialize(meta, frames[:col_frame_count]) columns.append(colobj) # Advance frames diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 325601e5311..b74128a8a61 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -7,7 +7,6 @@ import itertools import numbers import os -import pickle import re import sys import textwrap @@ -50,7 +49,6 @@ ) from cudf.core import column, df_protocol, indexing_utils, reshape from cudf.core._compat import PANDAS_LT_300 -from cudf.core.abc import Serializable from cudf.core.buffer import acquire_spill_lock, as_buffer from cudf.core.column import ( CategoricalColumn, @@ -588,7 +586,7 @@ class _DataFrameiAtIndexer(_DataFrameIlocIndexer): pass -class DataFrame(IndexedFrame, Serializable, GetAttrGetItemMixin): +class DataFrame(IndexedFrame, GetAttrGetItemMixin): """ A GPU Dataframe object. @@ -1190,7 +1188,7 @@ def _constructor_expanddim(self): def serialize(self): header, frames = super().serialize() - header["index"], index_frames = self.index.serialize() + header["index"], index_frames = self.index.device_serialize() header["index_frame_count"] = len(index_frames) # For backwards compatibility with older versions of cuDF, index # columns are placed before data columns. @@ -1205,8 +1203,7 @@ def deserialize(cls, header, frames): header, frames[header["index_frame_count"] :] ) - idx_typ = pickle.loads(header["index"]["type-serialized"]) - index = idx_typ.deserialize(header["index"], frames[:index_nframes]) + index = cls.device_deserialize(header["index"], frames[:index_nframes]) obj.index = index return obj diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 801020664da..9bb29f1920a 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -3,7 +3,6 @@ import decimal import operator -import pickle import textwrap import warnings from functools import cached_property @@ -91,13 +90,13 @@ def dtype(arbitrary): raise TypeError(f"Cannot interpret {arbitrary} as a valid cuDF dtype") -def _decode_type( +def _check_type( cls: type, header: dict, frames: list, is_valid_class: Callable[[type, type], bool] = operator.is_, -) -> tuple[dict, list, type]: - """Decode metadata-encoded type and check validity +) -> None: + """Perform metadata-encoded type and check validity Parameters ---------- @@ -112,12 +111,6 @@ class performing deserialization serialization by `cls` (default is to check type equality), called as `is_valid_class(decoded_class, cls)`. - Returns - ------- - tuple - Tuple of validated headers, frames, and the decoded class - constructor. - Raises ------ AssertionError @@ -128,11 +121,11 @@ class performing deserialization f"Deserialization expected {header['frame_count']} frames, " f"but received {len(frames)}." ) - klass = pickle.loads(header["type-serialized"]) + klass = Serializable._name_type_map[header["type-serialized-name"]] assert is_valid_class( - klass, cls + klass, + cls, ), f"Header-encoded {klass=} does not match decoding {cls=}." - return header, frames, klass class _BaseDtype(ExtensionDtype, Serializable): @@ -305,13 +298,14 @@ def construct_from_string(self): def serialize(self): header = {} - header["type-serialized"] = pickle.dumps(type(self)) header["ordered"] = self.ordered frames = [] if self.categories is not None: - categories_header, categories_frames = self.categories.serialize() + categories_header, categories_frames = ( + self.categories.device_serialize() + ) header["categories"] = categories_header frames.extend(categories_frames) header["frame_count"] = len(frames) @@ -319,15 +313,14 @@ def serialize(self): @classmethod def deserialize(cls, header, frames): - header, frames, klass = _decode_type(cls, header, frames) + _check_type(cls, header, frames) ordered = header["ordered"] categories_header = header["categories"] categories_frames = frames - categories_type = pickle.loads(categories_header["type-serialized"]) - categories = categories_type.deserialize( + categories = Serializable.device_deserialize( categories_header, categories_frames ) - return klass(categories=categories, ordered=ordered) + return cls(categories=categories, ordered=ordered) def __repr__(self): return self.to_pandas().__repr__() @@ -495,12 +488,13 @@ def __hash__(self): def serialize(self) -> tuple[dict, list]: header: dict[str, Dtype] = {} - header["type-serialized"] = pickle.dumps(type(self)) frames = [] if isinstance(self.element_type, _BaseDtype): - header["element-type"], frames = self.element_type.serialize() + header["element-type"], frames = ( + self.element_type.device_serialize() + ) else: header["element-type"] = getattr( self.element_type, "name", self.element_type @@ -510,14 +504,14 @@ def serialize(self) -> tuple[dict, list]: @classmethod def deserialize(cls, header: dict, frames: list): - header, frames, klass = _decode_type(cls, header, frames) + _check_type(cls, header, frames) if isinstance(header["element-type"], dict): - element_type = pickle.loads( - header["element-type"]["type-serialized"] - ).deserialize(header["element-type"], frames) + element_type = Serializable.device_deserialize( + header["element-type"], frames + ) else: element_type = header["element-type"] - return klass(element_type=element_type) + return cls(element_type=element_type) @cached_property def itemsize(self): @@ -641,7 +635,6 @@ def __hash__(self): def serialize(self) -> tuple[dict, list]: header: dict[str, Any] = {} - header["type-serialized"] = pickle.dumps(type(self)) frames: list[Buffer] = [] @@ -649,33 +642,31 @@ def serialize(self) -> tuple[dict, list]: for k, dtype in self.fields.items(): if isinstance(dtype, _BaseDtype): - dtype_header, dtype_frames = dtype.serialize() + dtype_header, dtype_frames = dtype.device_serialize() fields[k] = ( dtype_header, (len(frames), len(frames) + len(dtype_frames)), ) frames.extend(dtype_frames) else: - fields[k] = pickle.dumps(dtype) + fields[k] = dtype.str header["fields"] = fields header["frame_count"] = len(frames) return header, frames @classmethod def deserialize(cls, header: dict, frames: list): - header, frames, klass = _decode_type(cls, header, frames) + _check_type(cls, header, frames) fields = {} for k, dtype in header["fields"].items(): if isinstance(dtype, tuple): dtype_header, (start, stop) = dtype - fields[k] = pickle.loads( - dtype_header["type-serialized"] - ).deserialize( + fields[k] = Serializable.device_deserialize( dtype_header, frames[start:stop], ) else: - fields[k] = pickle.loads(dtype) + fields[k] = np.dtype(dtype) return cls(fields) @cached_property @@ -838,7 +829,6 @@ def _from_decimal(cls, decimal): def serialize(self) -> tuple[dict, list]: return ( { - "type-serialized": pickle.dumps(type(self)), "precision": self.precision, "scale": self.scale, "frame_count": 0, @@ -848,11 +838,8 @@ def serialize(self) -> tuple[dict, list]: @classmethod def deserialize(cls, header: dict, frames: list): - header, frames, klass = _decode_type( - cls, header, frames, is_valid_class=issubclass - ) - klass = pickle.loads(header["type-serialized"]) - return klass(header["precision"], header["scale"]) + _check_type(cls, header, frames, is_valid_class=issubclass) + return cls(header["precision"], header["scale"]) def __eq__(self, other: Dtype) -> bool: if other is self: @@ -960,18 +947,17 @@ def __hash__(self): def serialize(self) -> tuple[dict, list]: header = { - "type-serialized": pickle.dumps(type(self)), - "fields": pickle.dumps((self.subtype, self.closed)), + "fields": (self.subtype.str, self.closed), "frame_count": 0, } return header, [] @classmethod def deserialize(cls, header: dict, frames: list): - header, frames, klass = _decode_type(cls, header, frames) - klass = pickle.loads(header["type-serialized"]) - subtype, closed = pickle.loads(header["fields"]) - return klass(subtype, closed=closed) + _check_type(cls, header, frames) + subtype, closed = header["fields"] + subtype = np.dtype(subtype) + return cls(subtype, closed=closed) def _is_categorical_dtype(obj): diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 84a3caf905f..00199cca828 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -3,7 +3,6 @@ from __future__ import annotations import operator -import pickle import warnings from collections import abc from typing import TYPE_CHECKING, Any, Literal @@ -24,6 +23,7 @@ from cudf.api.types import is_dtype_equal, is_scalar from cudf.core._compat import PANDAS_LT_300 from cudf.core._internals.search import search_sorted +from cudf.core.abc import Serializable from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ( ColumnBase, @@ -47,7 +47,7 @@ # TODO: It looks like Frame is missing a declaration of `copy`, need to add -class Frame(BinaryOperand, Scannable): +class Frame(BinaryOperand, Scannable, Serializable): """A collection of Column objects with an optional index. Parameters @@ -97,37 +97,80 @@ def ndim(self) -> int: @_performance_tracking def serialize(self): # TODO: See if self._data can be serialized outright + frames = [] header = { - "type-serialized": pickle.dumps(type(self)), - "column_names": pickle.dumps(self._column_names), - "column_rangeindex": pickle.dumps(self._data.rangeindex), - "column_multiindex": pickle.dumps(self._data.multiindex), - "column_label_dtype": pickle.dumps(self._data.label_dtype), - "column_level_names": pickle.dumps(self._data._level_names), + "column_label_dtype": None, + "dtype-is-cudf-serialized": False, } - header["columns"], frames = serialize_columns(self._columns) + if (label_dtype := self._data.label_dtype) is not None: + try: + header["column_label_dtype"], frames = ( + label_dtype.device_serialize() + ) + header["dtype-is-cudf-serialized"] = True + except AttributeError: + header["column_label_dtype"] = label_dtype.str + + header["columns"], column_frames = serialize_columns(self._columns) + column_names, column_names_numpy_type = ( + zip( + *[ + (cname.item(), type(cname).__name__) + if isinstance(cname, np.generic) + else (cname, "") + for cname in self._column_names + ] + ) + if self._column_names + else ((), ()) + ) + header |= { + "column_names": column_names, + "column_names_numpy_type": column_names_numpy_type, + "column_rangeindex": self._data.rangeindex, + "column_multiindex": self._data.multiindex, + "column_level_names": self._data._level_names, + } + frames.extend(column_frames) + return header, frames @classmethod @_performance_tracking def deserialize(cls, header, frames): - cls_deserialize = pickle.loads(header["type-serialized"]) - column_names = pickle.loads(header["column_names"]) - columns = deserialize_columns(header["columns"], frames) kwargs = {} + dtype_header = header["column_label_dtype"] + if header["dtype-is-cudf-serialized"]: + count = dtype_header["frame_count"] + kwargs["label_dtype"] = cls.device_deserialize( + header, frames[:count] + ) + frames = frames[count:] + else: + kwargs["label_dtype"] = ( + np.dtype(dtype_header) if dtype_header is not None else None + ) + + columns = deserialize_columns(header["columns"], frames) for metadata in [ "rangeindex", "multiindex", - "label_dtype", "level_names", ]: key = f"column_{metadata}" if key in header: - kwargs[metadata] = pickle.loads(header[key]) + kwargs[metadata] = header[key] + + column_names = [ + getattr(np, cntype)(cname) if cntype != "" else cname + for cname, cntype in zip( + header["column_names"], header["column_names_numpy_type"] + ) + ] col_accessor = ColumnAccessor( data=dict(zip(column_names, columns)), **kwargs ) - return cls_deserialize._from_data(col_accessor) + return cls._from_data(col_accessor) @classmethod @_performance_tracking diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 0f12f266a95..d4f3394833a 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -3,7 +3,6 @@ import copy import itertools -import pickle import textwrap import warnings from collections import abc @@ -1281,7 +1280,7 @@ def serialize(self): obj_header, obj_frames = self.obj.serialize() header["obj"] = obj_header - header["obj_type"] = pickle.dumps(type(self.obj)) + header["obj_type_name"] = type(self.obj).__name__ header["num_obj_frames"] = len(obj_frames) frames.extend(obj_frames) @@ -1296,7 +1295,7 @@ def serialize(self): def deserialize(cls, header, frames): kwargs = header["kwargs"] - obj_type = pickle.loads(header["obj_type"]) + obj_type = Serializable._name_type_map[header["obj_type_name"]] obj = obj_type.deserialize( header["obj"], frames[: header["num_obj_frames"]] ) @@ -3329,8 +3328,8 @@ def _handle_misc(self, by): def serialize(self): header = {} frames = [] - header["names"] = pickle.dumps(self.names) - header["_named_columns"] = pickle.dumps(self._named_columns) + header["names"] = self.names + header["_named_columns"] = self._named_columns column_header, column_frames = cudf.core.column.serialize_columns( self._key_columns ) @@ -3340,8 +3339,8 @@ def serialize(self): @classmethod def deserialize(cls, header, frames): - names = pickle.loads(header["names"]) - _named_columns = pickle.loads(header["_named_columns"]) + names = header["names"] + _named_columns = header["_named_columns"] key_columns = cudf.core.column.deserialize_columns( header["columns"], frames ) diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index cc3d8448151..eeb6e3bd547 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -3,7 +3,6 @@ from __future__ import annotations import operator -import pickle import warnings from collections.abc import Hashable, MutableMapping from functools import cache, cached_property @@ -497,9 +496,8 @@ def serialize(self): header["index_column"]["step"] = self.step frames = [] - header["name"] = pickle.dumps(self.name) - header["dtype"] = pickle.dumps(self.dtype) - header["type-serialized"] = pickle.dumps(type(self)) + header["name"] = self.name + header["dtype"] = self.dtype.str header["frame_count"] = 0 return header, frames @@ -507,11 +505,14 @@ def serialize(self): @_performance_tracking def deserialize(cls, header, frames): h = header["index_column"] - name = pickle.loads(header["name"]) + name = header["name"] start = h["start"] stop = h["stop"] step = h.get("step", 1) - return RangeIndex(start=start, stop=stop, step=step, name=name) + dtype = np.dtype(header["dtype"]) + return RangeIndex( + start=start, stop=stop, step=step, dtype=dtype, name=name + ) @property # type: ignore @_performance_tracking diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index 173d4e1c584..5a41a33e583 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -5,7 +5,6 @@ import itertools import numbers import operator -import pickle import warnings from functools import cached_property from typing import TYPE_CHECKING, Any @@ -921,15 +920,15 @@ def take(self, indices) -> Self: def serialize(self): header, frames = super().serialize() # Overwrite the names in _data with the true names. - header["column_names"] = pickle.dumps(self.names) + header["column_names"] = self.names return header, frames @classmethod @_performance_tracking def deserialize(cls, header, frames): # Spoof the column names to construct the frame, then set manually. - column_names = pickle.loads(header["column_names"]) - header["column_names"] = pickle.dumps(range(0, len(column_names))) + column_names = header["column_names"] + header["column_names"] = range(0, len(column_names)) obj = super().deserialize(header, frames) return obj._set_names(column_names) diff --git a/python/cudf/cudf/core/resample.py b/python/cudf/cudf/core/resample.py index d95d252559f..391ee31f125 100644 --- a/python/cudf/cudf/core/resample.py +++ b/python/cudf/cudf/core/resample.py @@ -15,7 +15,6 @@ # limitations under the License. from __future__ import annotations -import pickle import warnings from typing import TYPE_CHECKING @@ -26,6 +25,7 @@ import cudf from cudf._lib.column import Column +from cudf.core.abc import Serializable from cudf.core.buffer import acquire_spill_lock from cudf.core.groupby.groupby import ( DataFrameGroupBy, @@ -97,21 +97,21 @@ def serialize(self): header, frames = super().serialize() grouping_head, grouping_frames = self.grouping.serialize() header["grouping"] = grouping_head - header["resampler_type"] = pickle.dumps(type(self)) + header["resampler_type"] = type(self).__name__ header["grouping_frames_count"] = len(grouping_frames) frames.extend(grouping_frames) return header, frames @classmethod def deserialize(cls, header, frames): - obj_type = pickle.loads(header["obj_type"]) + obj_type = Serializable._name_type_map[header["obj_type_name"]] obj = obj_type.deserialize( header["obj"], frames[: header["num_obj_frames"]] ) grouping = _ResampleGrouping.deserialize( header["grouping"], frames[header["num_obj_frames"] :] ) - resampler_cls = pickle.loads(header["resampler_type"]) + resampler_cls = Serializable._name_type_map[header["resampler_type"]] out = resampler_cls.__new__(resampler_cls) out.grouping = grouping super().__init__(out, obj, by=grouping) @@ -163,8 +163,8 @@ def serialize(self): @classmethod def deserialize(cls, header, frames): - names = pickle.loads(header["names"]) - _named_columns = pickle.loads(header["_named_columns"]) + names = header["names"] + _named_columns = header["_named_columns"] key_columns = cudf.core.column.deserialize_columns( header["columns"], frames[: -header["__bin_labels_count"]] ) diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index be74b0f867a..647e20fc16b 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -4,7 +4,6 @@ import functools import inspect -import pickle import textwrap import warnings from collections import abc @@ -27,7 +26,6 @@ ) from cudf.core import indexing_utils from cudf.core._compat import PANDAS_LT_300 -from cudf.core.abc import Serializable from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ( ColumnBase, @@ -414,7 +412,7 @@ def _loc_to_iloc(self, arg): return indices -class Series(SingleColumnFrame, IndexedFrame, Serializable): +class Series(SingleColumnFrame, IndexedFrame): """ One-dimensional GPU array (including time series). @@ -899,7 +897,7 @@ def hasnans(self): def serialize(self): header, frames = super().serialize() - header["index"], index_frames = self.index.serialize() + header["index"], index_frames = self.index.device_serialize() header["index_frame_count"] = len(index_frames) # For backwards compatibility with older versions of cuDF, index # columns are placed before data columns. @@ -915,8 +913,7 @@ def deserialize(cls, header, frames): header, frames[header["index_frame_count"] :] ) - idx_typ = pickle.loads(header["index"]["type-serialized"]) - index = idx_typ.deserialize(header["index"], frames[:index_nframes]) + index = cls.device_deserialize(header["index"], frames[:index_nframes]) obj.index = index return obj diff --git a/python/cudf/cudf/tests/data/pkl/stringColumnWithRangeIndex_cudf_23.12.pkl b/python/cudf/cudf/tests/data/pkl/stringColumnWithRangeIndex_cudf_23.12.pkl index 1ec077d10f7..64e06f0631d 100644 Binary files a/python/cudf/cudf/tests/data/pkl/stringColumnWithRangeIndex_cudf_23.12.pkl and b/python/cudf/cudf/tests/data/pkl/stringColumnWithRangeIndex_cudf_23.12.pkl differ diff --git a/python/cudf/cudf/tests/test_serialize.py b/python/cudf/cudf/tests/test_serialize.py index 68f2aaf9cab..b50ed04427f 100644 --- a/python/cudf/cudf/tests/test_serialize.py +++ b/python/cudf/cudf/tests/test_serialize.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import pytest +from packaging import version import cudf from cudf.testing import _utils as utils, assert_eq @@ -149,13 +150,19 @@ def test_serialize(df, to_host): def test_serialize_dtype_error_checking(): dtype = cudf.IntervalDtype("float", "right") - header, frames = dtype.serialize() - with pytest.raises(AssertionError): - # Invalid number of frames - type(dtype).deserialize(header, [None] * (header["frame_count"] + 1)) + # Must call device_serialize (not serialize) to ensure that the type metadata is + # encoded in the header. + header, frames = dtype.device_serialize() with pytest.raises(AssertionError): # mismatching class cudf.StructDtype.deserialize(header, frames) + # The is-cuda flag list length must match the number of frames + header["is-cuda"] = [False] + with pytest.raises(AssertionError): + # Invalid number of frames + type(dtype).deserialize( + header, [np.zeros(1)] * (header["frame_count"] + 1) + ) def test_serialize_dataframe(): @@ -382,6 +389,10 @@ def test_serialize_string_check_buffer_sizes(): assert expect == got +@pytest.mark.skipif( + version.parse(np.__version__) < version.parse("2.0.0"), + reason="The serialization of numpy 2.0 types is incompatible with numpy 1.x", +) def test_deserialize_cudf_23_12(datadir): fname = datadir / "pkl" / "stringColumnWithRangeIndex_cudf_23.12.pkl" diff --git a/python/cudf/cudf/tests/test_struct.py b/python/cudf/cudf/tests/test_struct.py index 899d78c999b..b85943626a6 100644 --- a/python/cudf/cudf/tests/test_struct.py +++ b/python/cudf/cudf/tests/test_struct.py @@ -79,7 +79,7 @@ def test_series_construction_with_nulls(): ) def test_serialize_struct_dtype(fields): dtype = cudf.StructDtype(fields) - recreated = dtype.__class__.deserialize(*dtype.serialize()) + recreated = dtype.__class__.device_deserialize(*dtype.device_serialize()) assert recreated == dtype diff --git a/python/dask_cudf/dask_cudf/tests/test_distributed.py b/python/dask_cudf/dask_cudf/tests/test_distributed.py index d03180852eb..c28b7e49207 100644 --- a/python/dask_cudf/dask_cudf/tests/test_distributed.py +++ b/python/dask_cudf/dask_cudf/tests/test_distributed.py @@ -4,7 +4,7 @@ import pytest import dask -from dask import dataframe as dd +from dask import array as da, dataframe as dd from dask.distributed import Client from distributed.utils_test import cleanup, loop, loop_in_thread # noqa: F401 @@ -121,3 +121,17 @@ def test_unique(): ddf.x.unique().compute(), check_index=False, ) + + +def test_serialization_of_numpy_types(): + # Dask uses numpy integers as column names, which can break cudf serialization + with dask_cuda.LocalCUDACluster(n_workers=1) as cluster: + with Client(cluster): + with dask.config.set( + {"dataframe.backend": "cudf", "array.backend": "cupy"} + ): + rng = da.random.default_rng() + X_arr = rng.random((100, 10), chunks=(50, 10)) + X = dd.from_dask_array(X_arr) + X = X[X.columns[0]] + X.compute()