diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py
index b2f92aeddda..799e6eddab3 100644
--- a/python/dask_cudf/dask_cudf/expr/_collection.py
+++ b/python/dask_cudf/dask_cudf/expr/_collection.py
@@ -108,3 +108,38 @@ class Index(DXIndex):
 get_collection_type.register(cudf.DataFrame, lambda _: DataFrame)
 get_collection_type.register(cudf.Series, lambda _: Series)
 get_collection_type.register(cudf.BaseIndex, lambda _: Index)
+
+
+##
+## Support conversion to GPU-backed Array collections
+##
+
+
+try:
+    from dask_expr._backends import create_array_collection
+except ImportError:
+    # Older version of dask-expr.
+    # Implicit conversion to array won't work.
+    pass
+else:
+    # The cupy/cupyx imports live inside the lazily registered
+    # callbacks, so importing this module does not import them.
+
+    @get_collection_type.register_lazy("cupy")
+    def _register_cupy():
+        """Map ``cupy.ndarray`` to ``create_array_collection``."""
+        import cupy
+
+        @get_collection_type.register(cupy.ndarray)
+        def get_collection_type_cupy_array(_):
+            return create_array_collection
+
+    @get_collection_type.register_lazy("cupyx")
+    def _register_cupyx():
+        """Map cupyx ``spmatrix`` to ``create_array_collection``."""
+        # Needed for cuml
+        from cupyx.scipy.sparse import spmatrix
+
+        @get_collection_type.register(spmatrix)
+        def get_collection_type_csr_matrix(_):
+            return create_array_collection
diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py
index 8a2f3414fd1..c6918c94559 100644
--- a/python/dask_cudf/dask_cudf/tests/test_core.py
+++ b/python/dask_cudf/dask_cudf/tests/test_core.py
@@ -913,3 +913,39 @@ def test_categorical_dtype_round_trip():
     actual = ds.compute()
     expected = pds.compute()
     assert actual.dtype.ordered == expected.dtype.ordered
+
+
+def test_implicit_array_conversion_cupy():
+    """``map_partitions`` returning ``.values`` matches the eager result."""
+    s = cudf.Series(range(10))
+    ds = dask_cudf.from_cudf(s, npartitions=2)
+
+    def func(x):
+        return x.values
+
+    # Need to compute the dask collection for now.
+    # See: https://github.com/dask/dask/issues/11017
+    result = ds.map_partitions(func, meta=s.values).compute()
+    expect = func(s)
+
+    dask.array.assert_eq(result, expect)
+
+
+def test_implicit_array_conversion_cupy_sparse():
+    """``map_partitions`` returning a cupyx CSR matrix keeps that type."""
+    cupyx = pytest.importorskip("cupyx")
+
+    s = cudf.Series(range(10), dtype="float32")
+    ds = dask_cudf.from_cudf(s, npartitions=2)
+
+    def func(x):
+        return cupyx.scipy.sparse.csr_matrix(x.values)
+
+    # Need to compute the dask collection for now.
+    # See: https://github.com/dask/dask/issues/11017
+    result = ds.map_partitions(func, meta=s.values).compute()
+    expect = func(s)
+
+    # NOTE: The calculation here doesn't need to make sense.
+    # We just need to make sure we get the right type back.
+    assert type(result) is type(expect)