diff --git a/python/dask_cudf/dask_cudf/backends.py b/python/dask_cudf/dask_cudf/backends.py index c7b4a1c4c6a..d05be30602e 100644 --- a/python/dask_cudf/dask_cudf/backends.py +++ b/python/dask_cudf/dask_cudf/backends.py @@ -2,6 +2,7 @@ import warnings from collections.abc import Iterator +from functools import partial import cupy as cp import numpy as np @@ -484,7 +485,6 @@ def sizeof_cudf_series_index(obj): def _simple_cudf_encode(_): # Basic pickle-based encoding for a partd k-v store import pickle - from functools import partial import partd @@ -686,6 +686,19 @@ def from_dict( constructor=constructor, ) + @staticmethod + def read_json(*args, engine="auto", **kwargs): + return _default_backend( + dd.read_json, + *args, + engine=( + partial(cudf.read_json, engine=engine) + if isinstance(engine, str) + else engine + ), + **kwargs, + ) + # Import/register cudf-specific classes for dask-expr try: diff --git a/python/dask_cudf/dask_cudf/io/tests/test_json.py b/python/dask_cudf/dask_cudf/io/tests/test_json.py index a2b1d7fc114..8dcf3f05e89 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_json.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_json.py @@ -12,8 +12,8 @@ import dask_cudf from dask_cudf.tests.utils import skip_dask_expr -# No dask-expr support -pytestmark = skip_dask_expr() +# No dask-expr support for dask_expr<=1.0.5 +pytestmark = skip_dask_expr(lt_version="1.0.5+a") def test_read_json_backend_dispatch(tmp_path): diff --git a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py index de2a735b2ce..df41ef77b7c 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py @@ -535,7 +535,7 @@ def test_check_file_size(tmpdir): dask_cudf.io.read_parquet(fn, check_file_size=1).compute() -@xfail_dask_expr("HivePartitioning cannot be hashed") +@xfail_dask_expr("HivePartitioning cannot be hashed", lt_version="1.0") def test_null_partition(tmpdir): import pyarrow as pa from pyarrow.dataset import HivePartitioning diff --git a/python/dask_cudf/dask_cudf/io/tests/test_s3.py b/python/dask_cudf/dask_cudf/io/tests/test_s3.py index f4a6fabdb60..a67404da4fe 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_s3.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_s3.py @@ -10,10 +10,6 @@ import pytest import dask_cudf -from dask_cudf.tests.utils import skip_dask_expr - -# No dask-expr support -pytestmark = skip_dask_expr() moto = pytest.importorskip("moto", minversion="3.1.6") boto3 = pytest.importorskip("boto3") @@ -111,7 +107,7 @@ def test_read_csv(s3_base, s3so): s3_base=s3_base, bucket="daskcsv", files={"a.csv": b"a,b\n1,2\n3,4\n"} ): df = dask_cudf.read_csv( - "s3://daskcsv/*.csv", chunksize="50 B", storage_options=s3so + "s3://daskcsv/*.csv", blocksize="50 B", storage_options=s3so ) assert df.a.sum().compute() == 4 diff --git a/python/dask_cudf/dask_cudf/tests/utils.py b/python/dask_cudf/dask_cudf/tests/utils.py index e838b8d63bc..1ca1758736b 100644 --- a/python/dask_cudf/dask_cudf/tests/utils.py +++ b/python/dask_cudf/dask_cudf/tests/utils.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +from packaging.version import Version import dask.dataframe as dd @@ -10,6 +11,13 @@ from dask_cudf.expr import QUERY_PLANNING_ON +if QUERY_PLANNING_ON: + import dask_expr + + DASK_EXPR_VERSION = Version(dask_expr.__version__) +else: + DASK_EXPR_VERSION = None + def _make_random_frame(nelem, npartitions=2, include_na=False): df = pd.DataFrame( @@ -27,9 +35,17 @@ def _make_random_frame(nelem, npartitions=2, include_na=False): _default_reason = "Not compatible with dask-expr" -def skip_dask_expr(reason=_default_reason): - return pytest.mark.skipif(QUERY_PLANNING_ON, reason=reason) +def skip_dask_expr(reason=_default_reason, lt_version=None): + if lt_version is not None: + skip = QUERY_PLANNING_ON and DASK_EXPR_VERSION < Version(lt_version) + else: + skip = QUERY_PLANNING_ON + return pytest.mark.skipif(skip, reason=reason) -def xfail_dask_expr(reason=_default_reason): - return pytest.mark.xfail(QUERY_PLANNING_ON, reason=reason) +def xfail_dask_expr(reason=_default_reason, lt_version=None): + if lt_version is not None: + xfail = QUERY_PLANNING_ON and DASK_EXPR_VERSION < Version(lt_version) + else: + xfail = QUERY_PLANNING_ON + return pytest.mark.xfail(xfail, reason=reason)