Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce basic "cudf" backend for Dask Expressions #14805

Merged
merged 136 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
136 commits
Select commit Hold shift + click to select a range
0da06d0
fix groupby get-group
rjzamora Jan 9, 2024
3a4a5a0
update copyright
rjzamora Jan 9, 2024
675964c
add new backend dispatching
rjzamora Jan 16, 2024
b26c270
add meta-based dispatching patch for demonstration
rjzamora Jan 17, 2024
b6a4cac
fix get_collection_type registration
rjzamora Jan 17, 2024
a6ee37a
re-org
rjzamora Jan 18, 2024
40687bd
rename
rjzamora Jan 19, 2024
04a1b24
comment
rjzamora Jan 19, 2024
b582d16
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Jan 19, 2024
9cd5c0b
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Jan 20, 2024
eb9fc88
update __dask_tokenize__
rjzamora Jan 22, 2024
61fae84
update __dask_tokenize__
rjzamora Jan 22, 2024
92a36d5
add test coverage
rjzamora Jan 22, 2024
4fd7db4
Merge remote-tracking branch 'upstream/branch-24.04' into fix-dask-no…
rjzamora Jan 23, 2024
c3d69f9
fix date
rjzamora Jan 23, 2024
ed11879
remove __new__ patchhing - we want to avoid checking _meta upon creation
rjzamora Jan 23, 2024
7b984b2
Avoid unnecessary caching
rjzamora Jan 23, 2024
ef1f82a
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Jan 23, 2024
dd32b2a
Merge branch 'branch-24.04' into fix-dask-normalize
rjzamora Jan 24, 2024
8b1da68
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Jan 24, 2024
befc090
Merge remote-tracking branch 'upstream/pandas_2.0_feature_branch' int…
rjzamora Jan 29, 2024
8850093
Merge branch 'branch-24.04' into fix-dask-normalize
rjzamora Jan 30, 2024
bcdd924
Merge branch 'branch-24.04' into fix-dask-normalize
rjzamora Jan 30, 2024
78975a4
Merge branch 'branch-24.04' into fix-dask-normalize
rjzamora Jan 30, 2024
399f618
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Jan 30, 2024
c01002b
move dask tests that depend on cudf code
rjzamora Jan 30, 2024
32d75fe
Merge branch 'fix-dask-normalize' into new-dask-expr-backend
rjzamora Jan 30, 2024
90defeb
revert change
rjzamora Jan 31, 2024
b43e7a2
Merge branch 'fix-dask-normalize' into new-dask-expr-backend
rjzamora Jan 31, 2024
4df89a0
add initial testing
rjzamora Jan 31, 2024
08e5761
remove comment
rjzamora Jan 31, 2024
c3f61e8
try to fix pre-commit failures
rjzamora Jan 31, 2024
58a3350
try to fix pre-commit failures again
rjzamora Jan 31, 2024
b385950
Use advice from Charles
rjzamora Jan 31, 2024
880e999
move tests and hold off on env changes for now
rjzamora Jan 31, 2024
fdc59a0
add back env changes with charles' suggestion
rjzamora Jan 31, 2024
793b56e
cw fix
rjzamora Jan 31, 2024
838a897
skip IO tests for now
rjzamora Jan 31, 2024
59e9b8b
remove dask-expr tests from conda-python-other-tests for now
rjzamora Jan 31, 2024
0aacea6
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Jan 31, 2024
916a0e4
remove extra cw change
rjzamora Jan 31, 2024
6a4bbae
fix cw dates
rjzamora Jan 31, 2024
9e296ff
fix cw dates
rjzamora Jan 31, 2024
3a2beb4
add csv test coverage
rjzamora Jan 31, 2024
b28b816
add parquet test coverage
rjzamora Feb 1, 2024
c9f522d
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 1, 2024
68be23c
adjust top-level dask_cudf API
rjzamora Feb 1, 2024
542f696
improve/simplify UX
rjzamora Feb 1, 2024
a271851
align with import changes
rjzamora Feb 1, 2024
20d6f31
fix from_dict
rjzamora Feb 1, 2024
e922d46
remove extra utility
rjzamora Feb 1, 2024
2634419
fix direct support with dask-expr
rjzamora Feb 2, 2024
8e8005b
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 2, 2024
c75e3b6
debugging groupby failures
rjzamora Feb 5, 2024
f417619
make sure series aggregate used the proper logic
rjzamora Feb 5, 2024
5cd5b5e
Merge branch 'branch-24.04' into fix-series-agg
rjzamora Feb 5, 2024
976144e
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 5, 2024
13a7d5a
Merge remote-tracking branch 'origin/fix-series-agg' into new-dask-ex…
rjzamora Feb 5, 2024
3485237
further align groupby API and tests
rjzamora Feb 6, 2024
706c55b
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 6, 2024
5f1fc47
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 7, 2024
2046aa8
opt groupby
rjzamora Feb 8, 2024
dd9faf7
cleanup
rjzamora Feb 8, 2024
146fd7d
only use optimized single-agg code path for collect
rjzamora Feb 9, 2024
6de8946
expand test coverage
rjzamora Feb 9, 2024
65dd6ad
add sort testing (limited support for now)
rjzamora Feb 9, 2024
d0118ec
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 9, 2024
1a8eabe
further expand test coverage
rjzamora Feb 9, 2024
e68d107
fix normalization
rjzamora Feb 12, 2024
faa54fc
add test coverage for accessorts, but some tun-explained test failure…
rjzamora Feb 12, 2024
569d01f
allow test_struct.py tests
rjzamora Feb 12, 2024
ccf14fd
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 12, 2024
1ace543
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 14, 2024
f7cd393
xfail test that needs newer version of dask-expr
rjzamora Feb 14, 2024
2091cb0
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Feb 16, 2024
d741963
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Feb 20, 2024
0334951
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Feb 22, 2024
286f60e
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Feb 22, 2024
6844647
simnplify dask-expr dependency
rjzamora Feb 22, 2024
18c6dbf
remove extra variable
rjzamora Feb 22, 2024
d8f77e9
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Feb 22, 2024
8c60154
handle newer versions of dask-expr
rjzamora Feb 23, 2024
0d5f6e7
Merge branch 'new-dask-expr-backend' of github.com:rjzamora/cudf into…
rjzamora Feb 26, 2024
ae44004
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 26, 2024
5a6b1c2
add conda testing for dask-expr
rjzamora Feb 26, 2024
ed9d2aa
change conda test
rjzamora Feb 26, 2024
bdf6fdc
fix cw date
rjzamora Feb 26, 2024
4f2fa9f
Fix GroupBy.get_group and GroupBy.indices
wence- Feb 26, 2024
d9d077f
Add test
wence- Feb 26, 2024
e51ebcc
Merge remote-tracking branch 'upstream/branch-24.04' into wence/fix/1…
wence- Feb 27, 2024
8a584c2
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 27, 2024
15c6b23
small code-review changes
rjzamora Feb 27, 2024
3475276
Merge branch 'branch-24.04' into wence/fix/14955
wence- Feb 27, 2024
da4422a
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 28, 2024
a2ad915
Merge branch 'branch-24.04' into wence/fix/14955
rjzamora Feb 28, 2024
f536837
remove optimized groupby code path since it is a lot of fragile code …
rjzamora Feb 28, 2024
7614926
simplify tokenization fix
rjzamora Feb 28, 2024
d3a97c7
add comments to expression patchhes
rjzamora Feb 28, 2024
daafd09
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 28, 2024
02547e9
roll back unused groupby changes
rjzamora Feb 28, 2024
e2ed6ff
remove space
rjzamora Feb 28, 2024
f4baec8
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 28, 2024
ef2fefc
update dependencies
rjzamora Feb 28, 2024
c9087ba
Merge branch 'branch-24.04' into wence/fix/14955
vyasr Feb 29, 2024
25f8d2b
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 29, 2024
47392c3
reduce diff
rjzamora Feb 29, 2024
481c655
adjust tokenization change
rjzamora Feb 29, 2024
576a601
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 29, 2024
026b98b
xfail groupby test that fails for older dask-expr version
rjzamora Feb 29, 2024
67f8be9
remove explicit dask-expr dependency as it should now be covered by r…
rjzamora Feb 29, 2024
9db0b5f
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Feb 29, 2024
f6aa070
Merge branch 'branch-24.04' into wence/fix/14955
rjzamora Feb 29, 2024
4a0763f
Merge remote-tracking branch 'wence/wence/fix/14955' into new-dask-ex…
rjzamora Feb 29, 2024
be88773
remove more code - sits on top of #15143
rjzamora Feb 29, 2024
32bfff3
roll back shuffle-group fix for now
rjzamora Mar 1, 2024
79a32d3
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Mar 1, 2024
185b506
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 4, 2024
ac4b9f0
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 4, 2024
a681fa8
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 5, 2024
ff4e438
update cw date
rjzamora Mar 5, 2024
363d203
improve check
rjzamora Mar 5, 2024
78be3f9
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Mar 5, 2024
a330d06
Apply suggestions from code review
rjzamora Mar 5, 2024
2a3dd01
Merge remote-tracking branch 'upstream/branch-24.04' into new-dask-ex…
rjzamora Mar 5, 2024
6d8e7ac
make ci happy
rjzamora Mar 5, 2024
7e57099
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 5, 2024
cfb67cb
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 5, 2024
dddd130
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 6, 2024
4432a4d
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 6, 2024
3dfd819
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 7, 2024
738d940
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 8, 2024
fef0d67
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 8, 2024
d222235
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 10, 2024
acaf920
Merge branch 'branch-24.04' into new-dask-expr-backend
bdice Mar 11, 2024
29977b0
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 11, 2024
fcd4f06
Merge branch 'branch-24.04' into new-dask-expr-backend
rjzamora Mar 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ci/test_wheel_dask_cudf.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.

set -eou pipefail

Expand Down Expand Up @@ -28,3 +28,7 @@ python -m pip install $(echo ./dist/dask_cudf*.whl)[test]

# Run tests in dask_cudf/tests and dask_cudf/io/tests
python -m pytest -n 8 ./python/dask_cudf/dask_cudf/

# Run tests in dask_cudf/tests and dask_cudf/io/tests with dask-expr
echo "Running dask-cudf tests with dask-expr enabled..."
DASK_DATAFRAME__QUERY_PLANNING=True python -m pytest -n 8 ./python/dask_cudf/dask_cudf/
1 change: 1 addition & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,6 @@ dependencies:
- typing_extensions>=4.0.0
- zlib>=1.2.13
- pip:
- git+https://github.com/dask-contrib/dask-expr.git@b588a9e15e90e0567061664ffc01374786686e20
- git+https://github.com/python-streamz/streamz.git@master
name: all_cuda-118_arch-x86_64
1 change: 1 addition & 0 deletions conda/environments/all_cuda-120_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,6 @@ dependencies:
- typing_extensions>=4.0.0
- zlib>=1.2.13
- pip:
- git+https://github.com/dask-contrib/dask-expr.git@b588a9e15e90e0567061664ffc01374786686e20
- git+https://github.com/python-streamz/streamz.git@master
name: all_cuda-120_arch-x86_64
12 changes: 12 additions & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,18 @@ dependencies:
packages:
- dask-cuda==24.4.*
- *numba
- output_types: conda
packages:
- pip
- pip:
# This should eventually move to rapids-dask-dependency
- &dask_expr_tip git+https://github.com/dask-contrib/dask-expr.git@b588a9e15e90e0567061664ffc01374786686e20
- output_types: requirements
packages:
- *dask_expr_tip
- output_types: pyproject
packages:
- dask-expr@git+https://github.com/dask-contrib/dask-expr.git@b588a9e15e90e0567061664ffc01374786686e20
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to use dask_expr_tip here to cut down on duplication, but have no idea if something like

Suggested change
- dask-expr@git+https://github.com/dask-contrib/dask-expr.git@b588a9e15e90e0567061664ffc01374786686e20
- dask-expr @ *dask_expr_tip

Would work here as I'm not the most familiar with rapids-dependency-file-generator

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds fine to me. As you know, I'm just deferring to your advice for these kinds of changes :)

Another background note: I'm expecting dask-expr to be copied entirely into dask/dask proper before the "query-planning" default is changed. When that happens, we should be able to remove all this stuff.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm - I'm not having much luck with this. Not sure if we can avoid the duplication in this case.

depends_on_cudf:
common:
- output_types: conda
Expand Down
56 changes: 48 additions & 8 deletions python/dask_cudf/dask_cudf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,69 @@
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.

import dask.dataframe as dd
from dask import config
from dask.dataframe import from_delayed

import cudf

from . import backends
from ._version import __git_commit__, __version__
from .core import DataFrame, Series, concat, from_cudf, from_dask_dataframe
from .groupby import groupby_agg
from .io import read_csv, read_json, read_orc, read_text, to_orc
from .core import concat, from_cudf, from_dask_dataframe
rjzamora marked this conversation as resolved.
Show resolved Hide resolved
from .expr import DASK_EXPR_ENABLED


def read_csv(*args, **kwargs):
with config.set({"dataframe.backend": "cudf"}):
return dd.read_csv(*args, **kwargs)


def read_json(*args, **kwargs):
with config.set({"dataframe.backend": "cudf"}):
return dd.read_json(*args, **kwargs)


def read_orc(*args, **kwargs):
with config.set({"dataframe.backend": "cudf"}):
return dd.read_orc(*args, **kwargs)


def read_parquet(*args, **kwargs):
with config.set({"dataframe.backend": "cudf"}):
return dd.read_parquet(*args, **kwargs)


def raise_not_implemented_error(attr_name):
def inner_func(*args, **kwargs):
raise NotImplementedError(
f"Top-level {attr_name} API is not available for dask-expr."
)

return inner_func


if DASK_EXPR_ENABLED:
from .expr._collection import DataFrame, Index, Series

groupby_agg = raise_not_implemented_error("groupby_agg")
read_text = raise_not_implemented_error("read_text")
to_orc = raise_not_implemented_error("to_orc")
else:
from .core import DataFrame, Index, Series
from .groupby import groupby_agg
from .io import read_text, to_orc

try:
from .io import read_parquet
except ImportError:
pass

__all__ = [
"DataFrame",
"Series",
"Index",
"from_cudf",
"from_dask_dataframe",
"concat",
"from_delayed",
]


if not hasattr(cudf.DataFrame, "mean"):
cudf.DataFrame.mean = None
del cudf
63 changes: 59 additions & 4 deletions python/dask_cudf/dask_cudf/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,13 +625,68 @@ def read_csv(*args, **kwargs):

@staticmethod
def read_hdf(*args, **kwargs):
from dask_cudf import from_dask_dataframe

# HDF5 reader not yet implemented in cudf
warnings.warn(
"read_hdf is not yet implemented in cudf/dask_cudf. "
"Moving to cudf from pandas. Expect poor performance!"
)
return from_dask_dataframe(
_default_backend(dd.read_hdf, *args, **kwargs)
return _default_backend(dd.read_hdf, *args, **kwargs).to_backend(
"cudf"
)


# Define "cudf" backend entrypoint for dask-expr
class CudfDXBackendEntrypoint(DataFrameBackendEntrypoint):
rjzamora marked this conversation as resolved.
Show resolved Hide resolved
"""Backend-entrypoint class for Dask-Expressions

This class is registered under the name "cudf" for the
``dask-expr.dataframe.backends`` entrypoint in ``setup.cfg``.
Dask-DataFrame will use the methods defined in this class
in place of ``dask_expr.<creation-method>`` when the
"dataframe.backend" configuration is set to "cudf":

Examples
--------
>>> import dask
>>> import dask_expr
>>> with dask.config.set({"dataframe.backend": "cudf"}):
... ddf = dx.from_dict({"a": range(10)})
>>> type(ddf._meta)
<class 'cudf.core.dataframe.DataFrame'>
"""

@classmethod
def to_backend_dispatch(cls):
return CudfBackendEntrypoint.to_backend_dispatch()

@classmethod
def to_backend(cls, *args, **kwargs):
return CudfBackendEntrypoint.to_backend(*args, **kwargs)

@staticmethod
def from_dict(
data,
npartitions,
orient="columns",
dtype=None,
columns=None,
constructor=cudf.DataFrame,
):
import dask_expr as dx

return _default_backend(
dx.from_dict,
data,
npartitions=npartitions,
orient=orient,
dtype=dtype,
columns=columns,
constructor=constructor,
)


# Import/register cudf-specific classes for dask-expr
try:
import dask_cudf.expr # noqa: F401
except ImportError:
pass
17 changes: 13 additions & 4 deletions python/dask_cudf/dask_cudf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
from tlz import partition_all

from dask import dataframe as dd
from dask import config, dataframe as dd
from dask.base import normalize_token, tokenize
from dask.dataframe.core import (
Scalar,
Expand Down Expand Up @@ -690,13 +690,20 @@ def from_cudf(data, npartitions=None, chunksize=None, sort=True, name=None):
"dask_cudf does not support MultiIndex Dataframes."
)

name = name or ("from_cudf-" + tokenize(data, npartitions or chunksize))
# Dask-expr doesn't support the `name` argument
name = {}
if not config.get("dataframe.query-planning", False):
name = {
"name": name
or ("from_cudf-" + tokenize(data, npartitions or chunksize))
}

return dd.from_pandas(
data,
npartitions=npartitions,
chunksize=chunksize,
sort=sort,
name=name,
**name,
)


Expand All @@ -711,7 +718,9 @@ def from_cudf(data, npartitions=None, chunksize=None, sort=True, name=None):
rather than pandas objects.\n
"""
)
+ textwrap.dedent(dd.from_pandas.__doc__)
# TODO: `dd.from_pandas.__doc__` is empty when
# `DASK_DATAFRAME__QUERY_PLANNING=True`
rjzamora marked this conversation as resolved.
Show resolved Hide resolved
+ textwrap.dedent(dd.from_pandas.__doc__ or "")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly haven't had time to understand this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because the doc-string is currently missing in dask-expr

)


Expand Down
17 changes: 17 additions & 0 deletions python/dask_cudf/dask_cudf/expr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from dask import config

DASK_EXPR_ENABLED = False
if config.get("dataframe.query-planning", False):
# Make sure custom expressions and collections are defined
try:
import dask_cudf.expr._collection
import dask_cudf.expr._expr

DASK_EXPR_ENABLED = True
except ImportError:
# Dask Expressions not installed.
# Dask DataFrame should have already thrown an error
# before we got here.
pass
66 changes: 66 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from dask_expr import (
DataFrame as DXDataFrame,
FrameBase,
Index as DXIndex,
Series as DXSeries,
get_collection_type,
)

from dask import config

import cudf

##
## Custom collection classes
##


class DataFrame(DXDataFrame):
@classmethod
def from_dict(cls, *args, **kwargs):
with config.set({"dataframe.backend": "cudf"}):
return DXDataFrame.from_dict(*args, **kwargs)

def groupby(
self,
by,
group_keys=True,
sort=None,
observed=None,
dropna=None,
**kwargs,
):
from dask_cudf.expr._groupby import GroupBy

if isinstance(by, FrameBase) and not isinstance(by, DXSeries):
raise ValueError(
f"`by` must be a column name or list of columns, got {by}."
)
Comment on lines +69 to +72
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (non-blocking): What are the things that you might otherwise group on that we don't support?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is mostly guarding against a pd.Grouper argument.


return GroupBy(
self,
by,
group_keys=group_keys,
sort=sort,
observed=observed,
dropna=dropna,
**kwargs,
)


class Series(DXSeries):
def groupby(self, by, **kwargs):
from dask_cudf.expr._groupby import SeriesGroupBy

return SeriesGroupBy(self, by, **kwargs)


class Index(DXIndex):
pass # Same as pandas (for now)


get_collection_type.register(cudf.DataFrame, lambda _: DataFrame)
get_collection_type.register(cudf.Series, lambda _: Series)
get_collection_type.register(cudf.BaseIndex, lambda _: Index)
34 changes: 34 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_expr.py
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would consider the logic in this class to be the most "fragile". We are literally patching Expr sub-classes to do things that will work for cudf-backed data. By "surgically" patching specific classes, we dramatically reduce the amount of code we need in dask-cudf.

Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from dask_expr._cumulative import CumulativeBlockwise, TakeLast

##
## Custom expression patching
##


class PatchCumulativeBlockwise(CumulativeBlockwise):
rjzamora marked this conversation as resolved.
Show resolved Hide resolved
@property
def _args(self) -> list:
return self.operands[:1]

@property
def _kwargs(self) -> dict:
# Must pass axis and skipna as kwargs in cudf
return {"axis": self.axis, "skipna": self.skipna}


CumulativeBlockwise._args = PatchCumulativeBlockwise._args
CumulativeBlockwise._kwargs = PatchCumulativeBlockwise._kwargs


def _takelast(a, skipna=True):
if not len(a):
return a
if skipna:
a = a.bfill()
# Cannot use `squeeze` with cudf
return a.tail(n=1).iloc[0]
rjzamora marked this conversation as resolved.
Show resolved Hide resolved


TakeLast.operation = staticmethod(_takelast)
rjzamora marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading