Skip to content

Commit

Permalink
Request UDF results as tiledb_json format by default. (#476)
Browse files Browse the repository at this point in the history
This is the final piece of the Pandas cross-compatibility puzzle:
we request the results of UDFs and task graphs as `tiledb_json`–encoded
values by default.  This means that Pandas dataframes within results
will be returned as Arrow and transparently decoded back to Pandas
for the user, along with all the other features TileDB JSON encoding
gives us.

There are some slight changes elsewhere:

- Tests which expect serialized tuples now get lists.
- We auto-detect the result format when fetching results by default.

We also unpin Pandas (yay) and restrict to Cloudpickle <3
(which, despite the ASCII art, I do not love to do).
  • Loading branch information
thetorpedodog authored Oct 18, 2023
1 parent 33fe826 commit 15788b6
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 39 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ dynamic = ["version"]
dependencies = [
"attrs>=21.4.0",
"certifi",
"cloudpickle>=1.4.1",
"cloudpickle>=1.4.1,<3",
"importlib-metadata",
"packaging",
"pandas>=1.2.4,<2",
"pandas>=1.2.4",
"pyarrow>=3.0.0",
"python-dateutil",
"six>=1.10",
Expand Down
6 changes: 3 additions & 3 deletions requirements-py3.7.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ certifi==2023.7.22
cloudpickle==2.2.1
importlib-metadata==6.7.0
numpy==1.21.6
packaging==23.1
packaging==23.2
pandas==1.3.5
pyarrow==12.0.1
python-dateutil==2.8.2
pytz==2023.3.post1
six==1.16.0
tblib==1.7.0
tiledb==0.23.0
tiledb==0.23.1
typing_extensions==4.7.1
urllib3==2.0.4
urllib3==2.0.7
xarray==0.20.2
zipp==3.15.0
17 changes: 9 additions & 8 deletions requirements-py3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ attrs==23.1.0
certifi==2023.7.22
cloudpickle==2.2.1
importlib-metadata==6.8.0
numpy==1.25.2
packaging==23.1
pandas==1.5.3
numpy==1.26.1
packaging==23.2
pandas==2.1.1
pyarrow==13.0.0
python-dateutil==2.8.2
pytz==2023.3.post1
six==1.16.0
tblib==1.7.0
tiledb==0.23.0
typing_extensions==4.7.1
urllib3==2.0.4
xarray==2023.8.0
zipp==3.16.2
tiledb==0.23.1
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.7
xarray==2023.9.0
zipp==3.17.0
21 changes: 13 additions & 8 deletions src/tiledb/cloud/_results/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import attrs
import urllib3

from tiledb.cloud import client
from tiledb.cloud import rest_api
from tiledb.cloud import tiledb_cloud_error as tce
from tiledb.cloud._common import futures
from tiledb.cloud._common import utils
from tiledb.cloud._results import decoders
from tiledb.cloud._results import stored_params
from .. import client
from .. import rest_api
from .. import tiledb_cloud_error as tce
from .._common import futures
from .._common import utils
from . import codecs
from . import decoders
from . import stored_params

TASK_ID_HEADER = "X-TILEDB-CLOUD-TASK-ID"
_T = TypeVar("_T")
Expand Down Expand Up @@ -165,7 +166,9 @@ def _maybe_uuid(id_str: Optional[str]) -> Optional[uuid.UUID]:
return None


def fetch_remote(task_id: uuid.UUID, decoder: decoders.AbstractDecoder[_T]) -> _T:
def fetch_remote(
task_id: uuid.UUID, decoder: Optional[decoders.AbstractDecoder[Any]] = None
) -> object:
api_instance = client.build(rest_api.TasksApi)
try:
resp: urllib3.HTTPResponse = api_instance.task_id_result_get(
Expand All @@ -174,6 +177,8 @@ def fetch_remote(task_id: uuid.UUID, decoder: decoders.AbstractDecoder[_T]) -> _
)
except rest_api.ApiException as exc:
raise tce.check_exc(exc) from None
if decoder is None:
return codecs.BinaryBlob.from_response(resp).decode()
try:
return decoder.decode(resp.data)
finally:
Expand Down
2 changes: 1 addition & 1 deletion src/tiledb/cloud/taskgraphs/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def udf(
func: functions.Funcable[_T],
args: types.Arguments = types.Arguments(),
*,
result_format: Optional[str] = "python_pickle",
result_format: Optional[str] = "tiledb_json",
include_source: bool = True,
image_name: Optional[str] = None,
timeout: Union[datetime.timedelta, int, None] = None,
Expand Down
7 changes: 4 additions & 3 deletions src/tiledb/cloud/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import uuid
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

from tiledb.cloud import array
from tiledb.cloud import client
Expand Down Expand Up @@ -134,10 +134,11 @@ def last_udf_task():
def fetch_results(
task_id: uuid.UUID,
*,
result_format: str = models.ResultFormat.NATIVE,
result_format: Optional[str] = None,
) -> Any:
"""Fetches the results of a previously-executed UDF or SQL query."""
return results.fetch_remote(task_id, decoders.Decoder(result_format))
decoder = None if result_format is None else decoders.Decoder(result_format)
return results.fetch_remote(task_id, decoder)


def fetch_results_pandas(
Expand Down
2 changes: 1 addition & 1 deletion src/tiledb/cloud/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def exec_base(
http_compressor: Optional[str] = "deflate",
include_source_lines: bool = True,
task_name: Optional[str] = None,
result_format: str = models.ResultFormat.NATIVE,
result_format: str = "tiledb_json",
result_format_version=None,
store_results: bool = False,
stored_param_uuids: Iterable[uuid.UUID] = (),
Expand Down
5 changes: 5 additions & 0 deletions tests/common/test_pickle_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Tuple

import numpy as np
import packaging.version as pkgver
import pandas as pd
import pytest

Expand Down Expand Up @@ -44,6 +45,10 @@ def import_tiledb_cloud():
import tiledb.cloud # noqa: F401


@pytest.mark.skipif(
pkgver.Version("2") <= pkgver.Version(pd.__version__),
reason="Pandas 2 is an unresolvable breaking change",
)
@pytest.mark.parametrize("pd_ver", ["1.2.4", "1.5.3"])
@pytest.mark.parametrize("name_want", RESULTS.items(), ids=lambda itm: itm[0])
def test_pandas_compat(pd_ver: str, name_want: Tuple[str, Any]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/taskgraphs/delayed/test_delayed_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_basic_functions(self):
c = passthrough.set(name="c")(a, b)
d = delayed.udf(repr, name="d")(c)

self.assertEqual("(('a', ()), ((), 'b'))", d.compute(30))
self.assertEqual("[['a', []], [[], 'b']]", d.compute(30))

def test_two_delayeds(self):
d_repr = delayed.udf(repr)
Expand Down
20 changes: 10 additions & 10 deletions tests/taskgraphs/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_basic(self):
"language_version": utils.PYTHON_VERSION,
},
"executable_code": "gASVFAAAAAAAAACMCGJ1aWx0aW5zlIwDbGVulJOULg==",
"result_format": "python_pickle",
"result_format": "tiledb_json",
},
},
{
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_basic(self):
"run_client_side": True,
},
"executable_code": "gASVQwAAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwbe2l0IXJ9IGhhcyBhIGxlbmd0aCBvZiB7bG59lIwGZm9ybWF0lIaUUpQu",
"result_format": "python_pickle",
"result_format": "tiledb_json",
},
},
],
Expand Down Expand Up @@ -328,7 +328,7 @@ def test_complex(self):
"language_version": utils.PYTHON_VERSION,
},
"executable_code": "gASVIwAAAAAAAACMCG9wZXJhdG9ylIwKaXRlbWdldHRlcpSTlIwBYZSFlFKULg==",
"result_format": "python_pickle",
"result_format": "tiledb_json",
},
},
{
Expand All @@ -352,7 +352,7 @@ def test_complex(self):
},
"executable_code": "gASVEQAAAAAAAACMBW51bXB5lIwDc3VtlJOULg==",
"source_text": functions.getsourcelines(numpy.sum),
"result_format": "python_pickle",
"result_format": "tiledb_json",
},
},
{
Expand All @@ -374,7 +374,7 @@ def test_complex(self):
"resource_class": "llama",
},
"executable_code": "gASVFAAAAAAAAACMCGJ1aWx0aW5zlIwDaW50lJOULg==",
"result_format": "python_pickle",
"result_format": "tiledb_json",
},
},
{
Expand Down Expand Up @@ -408,7 +408,7 @@ def test_complex(self):
"language_version": utils.PYTHON_VERSION,
},
"executable_code": "gASVQgAAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwac3VtIG9mIHtuYW1lIXJ9IGlzIHtzdW0hcn2UjAZmb3JtYXSUhpRSlC4=",
"result_format": "python_pickle",
"result_format": "tiledb_json",
},
},
{
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_complex(self):
"run_client_side": True,
},
"executable_code": "gASVTgAAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwmYXJyYXkgeyFyfSBnYXZlIHJlc3VsdCB7fTsgc3FsIGdhdmUge32UjAZmb3JtYXSUhpRSlC4=",
"result_format": "python_pickle",
"result_format": "tiledb_json",
},
},
{
Expand All @@ -464,7 +464,7 @@ def test_complex(self):
],
"environment": {},
"registered_udf_name": "TileDB-Inc/example_registration",
"result_format": "python_pickle",
"result_format": "tiledb_json",
},
},
],
Expand Down Expand Up @@ -500,7 +500,7 @@ def test_name_collisions(self):
"language_version": utils.PYTHON_VERSION,
},
"executable_code": utils.b64_pickle(_codec.b64_str),
"result_format": "python_pickle",
"result_format": "tiledb_json",
"source_text": (
"def b64_str(val: bytes) -> str:\n"
' return base64.b64encode(val).decode("ascii")\n'
Expand Down Expand Up @@ -540,7 +540,7 @@ def test_name_collisions(self):
"language_version": utils.PYTHON_VERSION,
},
"executable_code": utils.b64_pickle(_codec.b64_str),
"result_format": "python_pickle",
"result_format": "tiledb_json",
# No source text here!
},
},
Expand Down
4 changes: 2 additions & 2 deletions tests/taskgraphs/test_client_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def make_nice(arr, sql_result):
exec = client_executor.LocalExecutor(grf)
exec.execute(frm=2, to=3)
self.assertEqual(
(
[
{
"a": [[2, 3], [6, 7], [10, 11], [14, 15]],
"cols": [[2, 3], [2, 3], [2, 3], [2, 3]],
Expand All @@ -412,7 +412,7 @@ def make_nice(arr, sql_result):
dict(a=14, cols=2, rows=4),
dict(a=15, cols=3, rows=4),
],
),
],
exec.node(result).result(30),
)
exec.wait(5)
Expand Down

0 comments on commit 15788b6

Please sign in to comment.