Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable dask_cudf json and s3 tests with query-planning on #15408

Merged
merged 2 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion python/dask_cudf/dask_cudf/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from collections.abc import Iterator
from functools import partial

import cupy as cp
import numpy as np
Expand Down Expand Up @@ -484,7 +485,6 @@ def sizeof_cudf_series_index(obj):
def _simple_cudf_encode(_):
# Basic pickle-based encoding for a partd k-v store
import pickle
from functools import partial

import partd

Expand Down Expand Up @@ -686,6 +686,19 @@ def from_dict(
constructor=constructor,
)

@staticmethod
def read_json(*args, engine="auto", **kwargs):
    """Dispatch ``read_json`` to the default (dask) backend.

    A string ``engine`` (e.g. ``"auto"``) is wrapped with
    ``functools.partial`` so that ``cudf.read_json`` performs the
    actual parsing; a callable ``engine`` is forwarded unchanged.
    """
    if isinstance(engine, str):
        engine = partial(cudf.read_json, engine=engine)
    return _default_backend(dd.read_json, *args, engine=engine, **kwargs)


# Import/register cudf-specific classes for dask-expr
try:
Expand Down
4 changes: 2 additions & 2 deletions python/dask_cudf/dask_cudf/io/tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import dask_cudf
from dask_cudf.tests.utils import skip_dask_expr

# No dask-expr support
pytestmark = skip_dask_expr()
# No dask-expr support for dask_expr<=1.0.5
pytestmark = skip_dask_expr(lt_version="1.0.5+a")


def test_read_json_backend_dispatch(tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion python/dask_cudf/dask_cudf/io/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def test_check_file_size(tmpdir):
dask_cudf.io.read_parquet(fn, check_file_size=1).compute()


@xfail_dask_expr("HivePartitioning cannot be hashed")
@xfail_dask_expr("HivePartitioning cannot be hashed", lt_version="1.0")
def test_null_partition(tmpdir):
import pyarrow as pa
from pyarrow.dataset import HivePartitioning
Expand Down
6 changes: 1 addition & 5 deletions python/dask_cudf/dask_cudf/io/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
import pytest

import dask_cudf
from dask_cudf.tests.utils import skip_dask_expr

# No dask-expr support
pytestmark = skip_dask_expr()

moto = pytest.importorskip("moto", minversion="3.1.6")
boto3 = pytest.importorskip("boto3")
Expand Down Expand Up @@ -111,7 +107,7 @@ def test_read_csv(s3_base, s3so):
s3_base=s3_base, bucket="daskcsv", files={"a.csv": b"a,b\n1,2\n3,4\n"}
):
df = dask_cudf.read_csv(
"s3://daskcsv/*.csv", chunksize="50 B", storage_options=s3so
"s3://daskcsv/*.csv", blocksize="50 B", storage_options=s3so
)
assert df.a.sum().compute() == 4

Expand Down
24 changes: 20 additions & 4 deletions python/dask_cudf/dask_cudf/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
import numpy as np
import pandas as pd
import pytest
from packaging.version import Version

import dask.dataframe as dd

import cudf

from dask_cudf.expr import QUERY_PLANNING_ON

if QUERY_PLANNING_ON:
    import dask_expr

    # Version of the installed dask-expr package; used by the skip/xfail
    # helpers below to gate tests on a minimum dask-expr version.
    DASK_EXPR_VERSION = Version(dask_expr.__version__)
else:
    # Query-planning disabled: dask-expr may not be installed, so there
    # is no meaningful version to record.
    DASK_EXPR_VERSION = None


def _make_random_frame(nelem, npartitions=2, include_na=False):
df = pd.DataFrame(
Expand All @@ -27,9 +35,17 @@ def _make_random_frame(nelem, npartitions=2, include_na=False):
_default_reason = "Not compatible with dask-expr"


def skip_dask_expr(reason=_default_reason):
return pytest.mark.skipif(QUERY_PLANNING_ON, reason=reason)
def skip_dask_expr(reason=_default_reason, lt_version=None):
    """Build a skipif mark for tests incompatible with dask-expr.

    The test is skipped whenever query-planning is enabled; if
    ``lt_version`` is given, it is only skipped when the installed
    dask-expr release is older than that version.
    """
    should_skip = QUERY_PLANNING_ON
    if lt_version is not None:
        should_skip = should_skip and DASK_EXPR_VERSION < Version(lt_version)
    return pytest.mark.skipif(should_skip, reason=reason)


def xfail_dask_expr(reason=_default_reason):
return pytest.mark.xfail(QUERY_PLANNING_ON, reason=reason)
def xfail_dask_expr(reason=_default_reason, lt_version=None):
    """Build an xfail mark for tests incompatible with dask-expr.

    The test is expected to fail whenever query-planning is enabled;
    if ``lt_version`` is given, it is only expected to fail when the
    installed dask-expr release is older than that version.
    """
    should_xfail = QUERY_PLANNING_ON
    if lt_version is not None:
        should_xfail = should_xfail and DASK_EXPR_VERSION < Version(lt_version)
    return pytest.mark.xfail(should_xfail, reason=reason)
Loading