Skip to content

Commit

Permalink
Reorganize UDF data codecs to prepare for better serialization. (#463)
Browse files Browse the repository at this point in the history
To better separate concerns and eliminate cross-package circular
dependencies, this change moves the main handling of the `tiledb_json`
format to the `_results` package, and restructures the way codecs are
defined and handled.

This will allow us to introduce `tiledb_json` as a format for
serializing UDF results, so that Pandas dataframes can be serialized
with custom Arrow-based encoding/decoding logic instead of native
pickles. That, in turn, will let us offer client/server compatibility
across Pandas versions, without the hacks we currently resort to.
  • Loading branch information
thetorpedodog authored Sep 26, 2023
1 parent 718b045 commit 7db00a4
Show file tree
Hide file tree
Showing 16 changed files with 540 additions and 417 deletions.
216 changes: 216 additions & 0 deletions src/tiledb/cloud/_results/codecs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import abc
import base64
import json
from typing import Any, Generic, Tuple, Type, TypeVar

import attrs
import cloudpickle
import pyarrow
import urllib3
from typing_extensions import Self

# This is a circular dependency since we need to be able to decode `tiledb_json`
# format data.
from . import tiledb_json
from . import types

_ARROW_VERSION = pyarrow.MetadataVersion.V5
_PICKLE_PROTOCOL = 4
_T = TypeVar("_T")


class Codec(Generic[_T], metaclass=abc.ABCMeta):
    """Translates objects to bytes and vice versa. Purely classmethods."""

    __slots__ = ()

    NAME: str
    """The name to use for the codec, as used in ``result_format``."""

    MIME: str
    """The MIME type identifying this codec."""

    @classmethod
    @abc.abstractmethod
    def encode(cls, obj: _T) -> bytes:
        """Serializes ``obj`` into this codec's binary representation."""
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def decode(cls, data: bytes) -> _T:
        """Deserializes ``data`` back into a value."""
        raise NotImplementedError()

    @classmethod
    def decode_base64(cls, data: str) -> _T:
        """Decodes a base64 string of codec-encoded data into a value."""
        return cls.decode(base64.b64decode(data))

    @classmethod
    def encode_base64(cls, obj: _T) -> str:
        """Encodes ``obj`` and returns the result as a base64 string."""
        return base64.b64encode(cls.encode(obj)).decode("utf-8")


class ArrowCodec(Codec[pyarrow.Table]):
    """Encodes Arrow data into its default (IPC) stream format."""

    NAME = "arrow"
    MIME = "application/vnd.apache.arrow.stream"

    @classmethod
    def encode(cls, tbl: pyarrow.Table) -> bytes:
        """Writes ``tbl`` as a zstd-compressed Arrow IPC stream.

        Returns the stream contents (a ``pyarrow.Buffer``, which supports
        the buffer protocol anywhere ``bytes`` is accepted).
        """
        sink = pyarrow.BufferOutputStream()
        writer = pyarrow.RecordBatchStreamWriter(
            sink,
            tbl.schema,
            options=pyarrow.ipc.IpcWriteOptions(
                metadata_version=_ARROW_VERSION,
                compression="zstd",
            ),
        )
        # The writer must be closed before reading the sink so that the
        # end-of-stream marker is written; otherwise the IPC stream is
        # truncated per the Arrow spec (lenient readers may accept it,
        # but strict implementations will reject it).
        with writer:
            writer.write(tbl)
        return sink.getvalue()

    @classmethod
    def decode(cls, data: bytes) -> pyarrow.Table:
        """Reads an Arrow IPC stream back into a ``pyarrow.Table``."""
        # If a UDF didn't return any rows, there will not have been any batches
        # of data to write to the output, and thus it will not include any content
        # at all. (SQL queries will include headers.)
        if not data:
            # In this case, we need to return an empty table.
            return pyarrow.Table.from_pydict({})
        reader = pyarrow.RecordBatchStreamReader(data)
        return reader.read_all()


class BytesCodec(Codec[bytes]):
    """Identity codec: bytes pass through unchanged in both directions."""

    NAME = "bytes"
    MIME = "application/octet-stream"

    @classmethod
    def encode(cls, obj: bytes) -> bytes:
        """Returns ``obj`` unchanged."""
        return obj

    @classmethod
    def decode(cls, data: bytes) -> bytes:
        """Returns ``data`` unchanged."""
        return data


class JSONCodec(Codec[object]):
    """Round-trips values through standard JSON."""

    NAME = "json"
    MIME = "application/json"

    @classmethod
    def encode(cls, obj: object) -> bytes:
        """Serializes ``obj`` as UTF-8–encoded JSON text."""
        text = json.dumps(obj)
        return text.encode("utf-8")

    @classmethod
    def decode(cls, data: bytes) -> object:
        """Parses JSON bytes back into a Python value."""
        return json.loads(data)


class PickleCodec(Codec[object]):
    """Round-trips arbitrary objects through CloudPickle."""

    NAME = "python_pickle"
    MIME = "application/vnd.tiledb.python-pickle"

    @classmethod
    def encode(cls, obj: object) -> bytes:
        """Pickles ``obj`` using the module's pinned pickle protocol."""
        return cloudpickle.dumps(obj, protocol=_PICKLE_PROTOCOL)

    @classmethod
    def decode(cls, data: bytes) -> object:
        """Unpickles ``data``. Only safe on trusted input, as with any pickle."""
        return cloudpickle.loads(data)


class TileDBJSONCodec(Codec[object]):
    """Serializes objects with TileDB JSON."""

    NAME = "tiledb_json"
    MIME = "application/vnd.tiledb.udf-data+json"

    @classmethod
    def encode(cls, obj: object) -> bytes:
        """Writes ``obj`` as TileDB-JSON bytes."""
        return tiledb_json.dumps(obj)

    @classmethod
    def decode(cls, data: bytes) -> object:
        """Reads TileDB-JSON bytes back into an object."""
        return tiledb_json.loads(data)


ALL_CODECS: Tuple[Type[Codec[Any]], ...] = (
    ArrowCodec,
    BytesCodec,
    JSONCodec,
    PickleCodec,
    TileDBJSONCodec,
)
"""Every codec we have."""
# Maps each codec's NAME (the ``result_format`` string) to its class.
CODECS_BY_FORMAT = {c.NAME: c for c in ALL_CODECS}
# The format name "native" is handled as Python pickling.
CODECS_BY_FORMAT["native"] = PickleCodec
# Maps each codec's MIME type to its class (used to identify response bodies).
CODECS_BY_MIME = {c.MIME: c for c in ALL_CODECS}


@attrs.define(frozen=True, slots=False)
class BinaryBlob:
    """Container for a binary-encoded value, decoded on-demand.
    This is used to store results obtained from the server, such that it is not
    necessary to decode them between stages, and they only are decoded upon
    request.
    """

    # NOTE: ``slots=False`` is load-bearing here: ``decode`` caches its result
    # in ``self.__dict__``, which only exists on non-slotted instances, and
    # writing to ``__dict__`` directly sidesteps the frozen-instance check.

    format: str
    """The TileDB Cloud name of the data's format (see ``CODECS_BY_FORMAT``)."""
    data: bytes
    """The binary data itself."""

    @classmethod
    def from_response(cls, resp: urllib3.HTTPResponse) -> Self:
        """Reads a urllib3 response into an encoded result."""
        # Default to raw bytes when the server declares no content type.
        full_mime = resp.getheader("Content-type") or "application/octet-stream"
        # Strip any MIME parameters (e.g. "; charset=utf-8") before lookup.
        mime, _, _ = full_mime.partition(";")
        mime = mime.strip()
        try:
            format_name = CODECS_BY_MIME[mime].NAME
        except KeyError:
            # Unknown MIME type: record it verbatim with a "mime:" prefix so
            # ``decode`` reports it instead of guessing a codec.
            format_name = "mime:" + mime
        data = resp.data
        return cls(format_name, data)

    def decode(self) -> types.NativeValue:
        """Decodes this result into native Python data."""
        # This is not lock-protected because we're ok with decoding twice.
        try:
            # Cached in ``__dict__`` directly because the instance is frozen.
            return self.__dict__["_decoded"]
        except KeyError:
            pass
        try:
            loader = CODECS_BY_FORMAT[self.format].decode
        except KeyError:
            raise ValueError(f"Cannot decode {self.format!r} data")
        self.__dict__["_decoded"] = loader(self.data)
        return self.__dict__["_decoded"]

    def _tdb_to_json(self) -> types.TileDBJSONValue:
        """Renders this blob as a TileDB JSON "immediate" node."""
        return types.TileDBJSONValue(
            {
                "__tdbudf__": "immediate",
                "format": self.format,
                "base64_data": base64.b64encode(self.data).decode("ascii"),
            }
        )

    @classmethod
    def of(cls, obj: Any) -> Self:
        """Turns a non–JSON-encodable object into a ``BinaryBlob``."""
        if isinstance(obj, bytes):
            return cls("bytes", obj)
        if isinstance(obj, pyarrow.Table):
            return cls("arrow", ArrowCodec.encode(obj))
        # Fall back to pickling for anything else.
        return cls("python_pickle", PickleCodec.encode(obj))
26 changes: 4 additions & 22 deletions src/tiledb/cloud/_results/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import abc
import dataclasses
import json
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import cloudpickle
import pyarrow

from tiledb.cloud import tiledb_cloud_error as tce
from tiledb.cloud.rest_api import models

from . import codecs

if TYPE_CHECKING:
import pandas

Expand All @@ -25,24 +25,6 @@ def decode(self, data: bytes) -> Any:
raise NotImplementedError()


def _load_arrow(data: bytes) -> pyarrow.Table:
    """Reads an Arrow IPC stream into a Table; empty input yields an empty Table."""
    # If a UDF didn't return any rows, there will not have been any batches
    # of data to write to the output, and thus it will not include any content
    # at all. (SQL queries will include headers.)
    if not data:
        # In this case, we need to return an empty table.
        return pyarrow.Table.from_pydict({})
    reader = pyarrow.RecordBatchStreamReader(data)
    return reader.read_all()


# Maps each server result format to the function that decodes its raw bytes.
_DECODE_FNS = {
    models.ResultFormat.NATIVE: cloudpickle.loads,
    models.ResultFormat.JSON: json.loads,
    models.ResultFormat.ARROW: _load_arrow,
}


@dataclasses.dataclass(frozen=True)
class Decoder(AbstractDecoder[_T]):
"""General decoder for the formats we support.
Expand All @@ -55,10 +37,10 @@ class Decoder(AbstractDecoder[_T]):

def decode(self, data: bytes) -> _T:
try:
decoder = _DECODE_FNS[self.format]
codec = codecs.CODECS_BY_FORMAT[self.format]
except KeyError:
raise tce.TileDBCloudError(f"{self.format!r} is not a valid result format.")
return decoder(data)
return codec.decode(data)


@dataclasses.dataclass(frozen=True)
Expand Down
125 changes: 125 additions & 0 deletions src/tiledb/cloud/_results/tiledb_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import json
from typing import Any, Dict, Optional

from .._common import visitor
from . import codecs

SENTINEL_KEY = "__tdbudf__"
ESCAPE_CODE = "__escape__"


class Encoder(visitor.ReplacingVisitor):
    """Turns arbitrary Python values into TileDB JSON.
    This escapes arbitrary native values so that they can be JSON-serialized.
    It should only be used with ``NativeValue``s—that is, values that are
    already JSON-serializable, like ``RegisteredArg``s or ``CallArg``s, should
    *not* be passed to an ``Escaper``. The base implementation will return
    fully self-contained JSON-serializable objects, i.e. ``CallArg``s.
    """

    def maybe_replace(self, arg) -> Optional[visitor.Replacement]:
        if not is_jsonable_shallow(arg):
            # Not directly JSON-serializable: box it up as binary data.
            return visitor.Replacement(codecs.BinaryBlob.of(arg)._tdb_to_json())
        if isinstance(arg, dict) and SENTINEL_KEY in arg:
            # A plain dict that happens to contain our sentinel key must be
            # escaped so the decoder does not mistake it for a sentinel node.
            escaped = {key: self.visit(member) for (key, member) in arg.items()}
            return visitor.Replacement({SENTINEL_KEY: ESCAPE_CODE, ESCAPE_CODE: escaped})
        # Serializable as-is; let the visitor descend into it normally.
        return None


class Decoder(visitor.ReplacingVisitor):
    """A general-purpose replacer to decode sentinel-containing structures.
    This descends through data structures and replaces dictionaries containing
    ``__tdbudf__`` values with the unescaped values. This base implementation
    handles only the basics; you can create a derived version to handle specific
    situations (building arguments, replacing values, etc.).
    The data that is returned from this is generally a ``types.NativeValue``.
    """

    def maybe_replace(self, arg) -> Optional[visitor.Replacement]:
        # Only dicts carrying the sentinel key need special handling.
        if isinstance(arg, dict) and SENTINEL_KEY in arg:
            return self._replace_sentinel(arg[SENTINEL_KEY], arg)
        return None

    def _replace_sentinel(
        self,
        kind: str,
        value: Dict[str, Any],
    ) -> Optional[visitor.Replacement]:
        """The base implementation of a sentinel-replacer.
        It is passed the kind and value of a ``__tdbudf__``–containing object::
            # Given this:
            the_object = {"__tdbudf__": "node_data", "data": "abc"}
            # This will be called:
            self._replace_sentinel("node_data", the_object)
        This implementation handles replacing values that do not require any
        external information. Derived implementations should handle their own
        keys and end with a call to
        ``return super()._replace_sentinel(kind, value)``.
        """
        if kind == ESCAPE_CODE:
            # The dict stored under ESCAPE_CODE is the escaped original.
            # Rebuild it by visiting each of its members rather than visiting
            # the wrapper itself, which would re-trigger this escape handling.
            escaped = value[ESCAPE_CODE]
            rebuilt = {key: self.visit(member) for (key, member) in escaped.items()}
            return visitor.Replacement(rebuilt)
        if kind == "immediate":
            # An inline value: base64-encoded bytes plus the name of the
            # format used to look up the codec that decodes them.
            codec = codecs.CODECS_BY_FORMAT[value["format"]]
            return visitor.Replacement(codec.decode_base64(value["base64_data"]))
        raise ValueError(f"Unknown sentinel type {kind!r}")


def dumps(obj: object) -> bytes:
    """Dumps an object to TileDB UDF–encoded JSON."""
    encoded = Encoder().visit(obj)
    return json.dumps(encoded).encode("utf-8")


def loads(data: bytes) -> object:
    """Loads TileDB UDF–encoded JSON to an object."""
    return Decoder().visit(json.loads(data))


# Note: JSON-encoding a tuple produces a JSON array, the same as a list.
_NATIVE_JSONABLE = (
    str,
    int,
    bool,
    float,
    type(None),
    list,
    tuple,
)
"""Types whose direct values can always be turned into JSON."""


def is_jsonable_shallow(obj) -> bool:
    """True if ``obj``'s own type (ignoring its contents) is JSON-encodable."""
    if isinstance(obj, dict):
        # A dict is JSONable only if every key is a string.
        return all(isinstance(key, str) for key in obj)
    # Apart from dicts, only these types can be turned into JSON directly.
    return isinstance(obj, (str, int, bool, float, type(None), list, tuple))
Loading

0 comments on commit 7db00a4

Please sign in to comment.