Reorganize UDF data codecs to prepare for better serialization. (#463)
To better separate concerns and eliminate cross-package circular dependencies, this change moves the main handling of the `tiledb_json` format into the `_results` package and restructures the way codecs are defined and handled. This lets us introduce `tiledb_json` as a format for serializing UDF results, so that Pandas dataframes can be serialized with custom Arrow-based encoding/decoding logic instead of native pickles, which in turn lets us offer client/server compatibility across Pandas versions without the hacks we currently resort to.
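As an illustration of where this is headed, here is a minimal sketch of what an Arrow-backed Pandas codec could look like on top of the `Codec` base class introduced below. The `PandasArrowCodec` class, its `NAME`/`MIME` values, and the import path are illustrative assumptions, not part of this commit:

import pandas as pd
import pyarrow

# Assumed import path; the new module's exact location is not shown on this page.
from tiledb.cloud._results.codecs import Codec


class PandasArrowCodec(Codec[pd.DataFrame]):
    """Hypothetical codec: serializes Pandas dataframes via the Arrow IPC stream format."""

    NAME = "pandas_arrow"  # illustrative name only
    MIME = "application/vnd.apache.arrow.stream"

    @classmethod
    def encode(cls, df: pd.DataFrame) -> bytes:
        # Convert to an Arrow table and write it as an IPC stream, avoiding
        # pickles and their cross-version incompatibilities.
        tbl = pyarrow.Table.from_pandas(df)
        sink = pyarrow.BufferOutputStream()
        with pyarrow.RecordBatchStreamWriter(sink, tbl.schema) as writer:
            writer.write_table(tbl)
        return sink.getvalue().to_pybytes()

    @classmethod
    def decode(cls, data: bytes) -> pd.DataFrame:
        reader = pyarrow.RecordBatchStreamReader(data)
        return reader.read_all().to_pandas()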
1 parent 718b045
commit 7db00a4
Showing 16 changed files with 540 additions and 417 deletions.
@@ -0,0 +1,216 @@
import abc
import base64
import json
from typing import Any, Generic, Tuple, Type, TypeVar

import attrs
import cloudpickle
import pyarrow
import urllib3
from typing_extensions import Self

# This is a circular dependency since we need to be able to decode `tiledb_json`
# format data.
from . import tiledb_json
from . import types

_ARROW_VERSION = pyarrow.MetadataVersion.V5
_PICKLE_PROTOCOL = 4
_T = TypeVar("_T")


class Codec(Generic[_T], metaclass=abc.ABCMeta):
    """Translates objects to bytes and vice versa. Purely classmethods."""

    __slots__ = ()

    NAME: str
    """The name to use for the codec, as used in ``result_format``."""

    MIME: str
    """The MIME type identifying this codec."""

    @classmethod
    @abc.abstractmethod
    def encode(cls, obj: _T) -> bytes:
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def decode(cls, data: bytes) -> _T:
        raise NotImplementedError()

    @classmethod
    def decode_base64(cls, data: str) -> _T:
        data_bytes = base64.b64decode(data)
        return cls.decode(data_bytes)

    @classmethod
    def encode_base64(cls, obj: _T) -> str:
        data_bytes = cls.encode(obj)
        return base64.b64encode(data_bytes).decode("utf-8")


class ArrowCodec(Codec[pyarrow.Table]):
    """Encodes Arrow data into its default stream format."""

    NAME = "arrow"
    MIME = "application/vnd.apache.arrow.stream"

    @classmethod
    def encode(cls, tbl: pyarrow.Table) -> bytes:
        sink = pyarrow.BufferOutputStream()
        writer = pyarrow.RecordBatchStreamWriter(
            sink,
            tbl.schema,
            options=pyarrow.ipc.IpcWriteOptions(
                metadata_version=_ARROW_VERSION,
                compression="zstd",
            ),
        )
        writer.write(tbl)
        return sink.getvalue()

    @classmethod
    def decode(cls, data: bytes) -> pyarrow.Table:
        # If a UDF didn't return any rows, there will not have been any batches
        # of data to write to the output, and thus it will not include any content
        # at all. (SQL queries will include headers.)
        if not data:
            # In this case, we need to return an empty table.
            return pyarrow.Table.from_pydict({})
        reader = pyarrow.RecordBatchStreamReader(data)
        return reader.read_all()


class BytesCodec(Codec[bytes]):
    """Does nothing to bytes."""

    NAME = "bytes"
    MIME = "application/octet-stream"

    @classmethod
    def encode(cls, obj: bytes) -> bytes:
        return obj

    @classmethod
    def decode(cls, data: bytes) -> bytes:
        return data


class JSONCodec(Codec[object]):
    """Dumps/loads JSON."""

    NAME = "json"
    MIME = "application/json"

    @classmethod
    def encode(cls, obj: object) -> bytes:
        return json.dumps(obj).encode("utf-8")

    @classmethod
    def decode(cls, data: bytes) -> object:
        return json.loads(data)


class PickleCodec(Codec[object]):
    """Pickles objects using CloudPickle."""

    NAME = "python_pickle"
    MIME = "application/vnd.tiledb.python-pickle"

    @classmethod
    def encode(cls, obj: object) -> bytes:
        return cloudpickle.dumps(obj, protocol=_PICKLE_PROTOCOL)

    @classmethod
    def decode(cls, data: bytes) -> object:
        return cloudpickle.loads(data)


class TileDBJSONCodec(Codec[object]):
    """Serializes objects with TileDB JSON."""

    NAME = "tiledb_json"
    MIME = "application/vnd.tiledb.udf-data+json"

    @classmethod
    def encode(cls, obj: object) -> bytes:
        return tiledb_json.dumps(obj)

    @classmethod
    def decode(cls, data: bytes) -> object:
        return tiledb_json.loads(data)


ALL_CODECS: Tuple[Type[Codec[Any]], ...] = (
    ArrowCodec,
    BytesCodec,
    JSONCodec,
    PickleCodec,
    TileDBJSONCodec,
)
"""Every codec we have."""
CODECS_BY_FORMAT = {c.NAME: c for c in ALL_CODECS}
CODECS_BY_FORMAT["native"] = PickleCodec
CODECS_BY_MIME = {c.MIME: c for c in ALL_CODECS}


@attrs.define(frozen=True, slots=False)
class BinaryBlob:
    """Container for a binary-encoded value, decoded on-demand.

    This is used to store results obtained from the server, such that it is not
    necessary to decode them between stages, and they only are decoded upon
    request.
    """

    format: str
    """The TileDB Cloud name of the data's format (see ``CODECS_BY_FORMAT``)."""
    data: bytes
    """The binary data itself."""

    @classmethod
    def from_response(cls, resp: urllib3.HTTPResponse) -> Self:
        """Reads a urllib3 response into an encoded result."""
        full_mime = resp.getheader("Content-type") or "application/octet-stream"
        mime, _, _ = full_mime.partition(";")
        mime = mime.strip()
        try:
            format_name = CODECS_BY_MIME[mime].NAME
        except KeyError:
            format_name = "mime:" + mime
        data = resp.data
        return cls(format_name, data)

    def decode(self) -> types.NativeValue:
        """Decodes this result into native Python data."""
        # This is not lock-protected because we're ok with decoding twice.
        try:
            return self.__dict__["_decoded"]
        except KeyError:
            pass
        try:
            loader = CODECS_BY_FORMAT[self.format].decode
        except KeyError:
            raise ValueError(f"Cannot decode {self.format!r} data")
        self.__dict__["_decoded"] = loader(self.data)
        return self.__dict__["_decoded"]

    def _tdb_to_json(self) -> types.TileDBJSONValue:
        return types.TileDBJSONValue(
            {
                "__tdbudf__": "immediate",
                "format": self.format,
                "base64_data": base64.b64encode(self.data).decode("ascii"),
            }
        )

    @classmethod
    def of(cls, obj: Any) -> Self:
        """Turns a non–JSON-encodable object into a ``BinaryBlob``."""
        if isinstance(obj, bytes):
            return cls("bytes", obj)
        if isinstance(obj, pyarrow.Table):
            return cls("arrow", ArrowCodec.encode(obj))
        return cls("python_pickle", PickleCodec.encode(obj))
@@ -0,0 +1,125 @@
import json
from typing import Any, Dict, Optional

from .._common import visitor
from . import codecs

SENTINEL_KEY = "__tdbudf__"
ESCAPE_CODE = "__escape__"


class Encoder(visitor.ReplacingVisitor):
    """Turns arbitrary Python values into TileDB JSON.

    This escapes arbitrary native values so that they can be JSON-serialized.
    It should only be used with ``NativeValue``s—that is, values that are
    already JSON-serializable, like ``RegisteredArg``s or ``CallArg``s, should
    *not* be passed to an ``Escaper``. The base implementation will return
    fully self-contained JSON-serializable objects, i.e. ``CallArg``s.
    """

    def maybe_replace(self, arg) -> Optional[visitor.Replacement]:
        if is_jsonable_shallow(arg):
            if isinstance(arg, dict):
                if SENTINEL_KEY in arg:
                    return visitor.Replacement(
                        {
                            SENTINEL_KEY: ESCAPE_CODE,
                            ESCAPE_CODE: {k: self.visit(v) for (k, v) in arg.items()},
                        }
                    )
            return None
        return visitor.Replacement(codecs.BinaryBlob.of(arg)._tdb_to_json())


class Decoder(visitor.ReplacingVisitor):
    """A general-purpose replacer to decode sentinel-containing structures.

    This descends through data structures and replaces dictionaries containing
    ``__tdbudf__`` values with the unescaped values. This base implementation
    handles only the basics; you can create a derived version to handle specific
    situations (building arguments, replacing values, etc.).

    The data that is returned from this is generally a ``types.NativeValue``.
    """

    def maybe_replace(self, arg) -> Optional[visitor.Replacement]:
        if not isinstance(arg, dict):
            return None
        try:
            sentinel_name = arg[SENTINEL_KEY]
        except KeyError:
            return None
        return self._replace_sentinel(sentinel_name, arg)

    def _replace_sentinel(
        self,
        kind: str,
        value: Dict[str, Any],
    ) -> Optional[visitor.Replacement]:
        """The base implementation of a sentinel-replacer.

        It is passed the kind and value of a ``__tdbudf__``–containing object::

            # Given this:
            the_object = {"__tdbudf__": "node_data", "data": "abc"}
            # This will be called:
            self._replace_sentinel("node_data", the_object)

        This implementation handles replacing values that do not require any
        external information. Derived implementations should handle their own
        keys and end with a call to
        ``return super()._replace_sentinel(kind, value)``.
        """
        if kind == ESCAPE_CODE:
            # An escaped value.
            inner_value = value[ESCAPE_CODE]
            # We can't just visit `inner_value` here, since `inner_value`
            # is the thing which has the ESCAPE_CODE key that is being escaped.
            return visitor.Replacement(
                {k: self.visit(v) for (k, v) in inner_value.items()}
            )
        if kind == "immediate":
            # "immediate" values are values of the format
            fmt = value["format"]
            base64d = value["base64_data"]
            return visitor.Replacement(
                codecs.CODECS_BY_FORMAT[fmt].decode_base64(base64d)
            )
        raise ValueError(f"Unknown sentinel type {kind!r}")


def dumps(obj: object) -> bytes:
    """Dumps an object to TileDB UDF–encoded JSON."""
    enc = Encoder()
    result_json = enc.visit(obj)
    return json.dumps(result_json).encode("utf-8")


def loads(data: bytes) -> object:
    """Loads TileDB UDF–encoded JSON to an object."""
    data_json = json.loads(data)
    dec = Decoder()
    return dec.visit(data_json)


_NATIVE_JSONABLE = (
    str,
    int,
    bool,
    float,
    type(None),
    list,
    tuple,
)
"""Types whose direct values can always be turned into JSON."""


def is_jsonable_shallow(obj) -> bool:
    if isinstance(obj, _NATIVE_JSONABLE):
        return True
    if not isinstance(obj, dict):
        # Apart from the above types, only dicts are JSONable.
        return False
    # For a dict to be JSONable, all keys must be strings.
    return all(isinstance(key, str) for key in obj)
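And a quick round-trip sketch of the TileDB JSON encoding above (again, the import path is an assumption):

# Assumed import path for the module shown above.
from tiledb.cloud._results import tiledb_json

value = {
    "plain": [1, 2.5, None, "text"],  # shallow-JSONable: passes through as-is
    "binary": b"\x00\x01\x02",        # not JSONable: encoded as an "immediate" blob
    "__tdbudf__": "already taken",    # uses the sentinel key, so the dict gets escaped
}

encoded = tiledb_json.dumps(value)    # JSON bytes containing __tdbudf__ sentinels
decoded = tiledb_json.loads(encoded)  # sentinels unescaped, blobs decoded again

assert decoded["binary"] == b"\x00\x01\x02"
assert decoded["__tdbudf__"] == "already taken"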