Skip to content

Commit

Permalink
Remove cudf._lib.transform in favor of inlining pylibcudf (#17505)
Browse files: browse the repository at this point in the history
Contributes to #17317

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Matthew Murray (https://github.com/Matt711)

URL: #17505
  • Loading branch information
mroeschke authored Dec 9, 2024
1 parent 5b412dc commit 9df95d1
Show file tree
Hide file tree
Showing 14 changed files with 85 additions and 165 deletions.
1 change: 0 additions & 1 deletion python/cudf/cudf/_lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ set(cython_sources
stream_compaction.pyx
string_casting.pyx
strings_udf.pyx
transform.pyx
types.pyx
utils.pyx
)
Expand Down
113 changes: 0 additions & 113 deletions python/cudf/cudf/_lib/transform.pyx

This file was deleted.

4 changes: 1 addition & 3 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import cudf
from cudf import _lib as libcudf
from cudf._lib.transform import bools_to_mask
from cudf.core._internals import unary
from cudf.core.column import column
from cudf.core.column.methods import ColumnMethods
Expand Down Expand Up @@ -775,12 +774,11 @@ def to_pandas(
raise NotImplementedError(f"{arrow_type=} is not implemented.")

if self.categories.dtype.kind == "f":
new_mask = bools_to_mask(self.notnull())
col = type(self)(
data=self.data, # type: ignore[arg-type]
size=self.size,
dtype=self.dtype,
mask=new_mask,
mask=self.notnull().fillna(False).as_mask(),
children=self.children,
)
else:
Expand Down
34 changes: 25 additions & 9 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
drop_duplicates,
drop_nulls,
)
from cudf._lib.transform import bools_to_mask
from cudf._lib.types import size_type_dtype
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
Expand Down Expand Up @@ -373,10 +372,14 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase:

return result._with_type_metadata(cudf_dtype_from_pa_type(array.type))

@acquire_spill_lock()
def _get_mask_as_column(self) -> ColumnBase:
return libcudf.transform.mask_to_bools(
self.base_mask, self.offset, self.offset + len(self)
plc_column = plc.transform.mask_to_bools(
self.base_mask.get_ptr(mode="read"), # type: ignore[union-attr]
self.offset,
self.offset + len(self),
)
return type(self).from_pylibcudf(plc_column)

@cached_property
def memory_usage(self) -> int:
Expand Down Expand Up @@ -981,11 +984,14 @@ def as_mask(self) -> Buffer:
-------
Buffer
"""

if self.has_nulls():
raise ValueError("Column must have no nulls.")

return bools_to_mask(self)
with acquire_spill_lock():
mask, _ = plc.transform.bools_to_mask(
self.to_pylibcudf(mode="read")
)
return as_buffer(mask)

@property
def is_unique(self) -> bool:
Expand Down Expand Up @@ -1514,6 +1520,18 @@ def _return_sentinel_column():
)
return codes.fillna(na_sentinel.value)

@acquire_spill_lock()
def one_hot_encode(
    self, categories: ColumnBase
) -> abc.Generator[ColumnBase]:
    """One-hot encode this column against ``categories``.

    Both columns are handed to ``plc.transform.one_hot_encode`` as
    read-only pylibcudf views, and every column of the resulting table
    is wrapped back into a column of this column's type.

    Parameters
    ----------
    categories : ColumnBase
        Column passed as the second argument to
        ``plc.transform.one_hot_encode``.

    Returns
    -------
    Generator of ColumnBase
        A lazy generator over the columns of the table returned by
        pylibcudf, in table order.

    Notes
    -----
    The spill lock is held for the duration of the pylibcudf call so
    that the input buffers stay valid while their device pointers are
    in use, matching the other pylibcudf call sites in this module.
    """
    # The pylibcudf call runs eagerly, i.e. while the spill lock is held.
    plc_table = plc.transform.one_hot_encode(
        self.to_pylibcudf(mode="read"),
        categories.to_pylibcudf(mode="read"),
    )
    # Only the wrapping of the (already computed) output columns is
    # deferred; it no longer touches the input buffers.
    return (
        type(self).from_pylibcudf(col, data_ptr_exposed=True)
        for col in plc_table.columns()
    )


def _has_any_nan(arbitrary: pd.Series | np.ndarray) -> bool:
"""Check if an object dtype Series or array contains NaN."""
Expand Down Expand Up @@ -2093,8 +2111,7 @@ def as_column(
)
# Consider NaT as NA in the mask
# but maintain NaT as a value
bool_mask = as_column(~is_nat)
mask = as_buffer(bools_to_mask(bool_mask))
mask = as_column(~is_nat).as_mask()
buffer = as_buffer(arbitrary.view("|u1"))
col = build_column(data=buffer, mask=mask, dtype=arbitrary.dtype)
if dtype:
Expand Down Expand Up @@ -2264,8 +2281,7 @@ def _mask_from_cuda_array_interface_desc(obj, cai_mask) -> Buffer:
)
return as_buffer(data=desc["data"][0], size=mask_size, owner=obj)
elif typecode == "b":
col = as_column(cai_mask)
return bools_to_mask(col)
return as_column(cai_mask).as_mask()
else:
raise NotImplementedError(f"Cannot infer mask from typestr {typestr}")

Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def from_sequences(
data=None,
size=len(arbitrary),
dtype=cudf.ListDtype(data_col.dtype),
mask=cudf._lib.transform.bools_to_mask(as_column(mask_col)),
mask=as_column(mask_col).as_mask(),
offset=0,
null_count=0,
children=(offset_col, data_col),
Expand Down
30 changes: 25 additions & 5 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,24 @@

import numpy as np
import pandas as pd
from numba.np import numpy_support
from typing_extensions import Self

import pylibcudf
import pylibcudf as plc

import cudf
import cudf.core.column.column as column
import cudf.core.column.string as string
from cudf import _lib as libcudf
from cudf.api.types import is_integer, is_scalar
from cudf.core._internals import binaryop, unary
from cudf.core.buffer import acquire_spill_lock, as_buffer
from cudf.core.column.column import ColumnBase, as_column
from cudf.core.column.numerical_base import NumericalBaseColumn
from cudf.core.dtypes import CategoricalDtype
from cudf.core.mixins import BinaryOperand
from cudf.errors import MixedTypeError
from cudf.utils import cudautils
from cudf.utils.dtypes import (
find_common_type,
min_column_type,
Expand Down Expand Up @@ -179,13 +182,27 @@ def __setitem__(self, key: Any, value: Any):
if out:
self._mimic_inplace(out, inplace=True)

@acquire_spill_lock()
def transform(self, compiled_op, np_dtype: np.dtype) -> ColumnBase:
    """Run a compiled element-wise UDF over this column via pylibcudf.

    Parameters
    ----------
    compiled_op
        Indexable object whose first element is forwarded to
        ``plc.transform.transform`` as the UDF source.
    np_dtype : np.dtype
        NumPy dtype used to derive the output pylibcudf data type.

    Returns
    -------
    ColumnBase
        The transformed column, wrapped in this column's type.
    """
    input_column = self.to_pylibcudf(mode="read")
    udf_source = compiled_op[0]
    # ``np_dtype.str`` starts with a one-character prefix (e.g. "<f8");
    # only the part after that first character is passed on.
    output_type = plc.column._datatype_from_dtype_desc(np_dtype.str[1:])
    # NOTE(review): the trailing positional ``True`` is presumably
    # pylibcudf's ``is_ptx`` flag -- confirm against the installed API.
    result = plc.transform.transform(
        input_column, udf_source, output_type, True
    )
    return type(self).from_pylibcudf(result)

def unary_operator(self, unaryop: str | Callable) -> ColumnBase:
if callable(unaryop):
return libcudf.transform.transform(self, unaryop)
nb_type = numpy_support.from_dtype(self.dtype)
nb_signature = (nb_type,)
compiled_op = cudautils.compile_udf(unaryop, nb_signature)
np_dtype = np.dtype(compiled_op[1])
return self.transform(compiled_op, np_dtype)

unaryop = unaryop.upper()
unaryop = _unaryop_map.get(unaryop, unaryop)
unaryop = pylibcudf.unary.UnaryOperator[unaryop]
unaryop = plc.unary.UnaryOperator[unaryop]
return unary.unary_operation(self, unaryop)

def __invert__(self):
Expand Down Expand Up @@ -298,8 +315,11 @@ def nans_to_nulls(self: Self) -> Self:
# Only floats can contain nan.
if self.dtype.kind != "f" or self.nan_count == 0:
return self
newmask = libcudf.transform.nans_to_nulls(self)
return self.set_mask(newmask)
with acquire_spill_lock():
mask, _ = plc.transform.nans_to_nulls(
self.to_pylibcudf(mode="read")
)
return self.set_mask(as_buffer(mask))

def normalize_binop_value(self, other: ScalarLike) -> Self | cudf.Scalar:
if isinstance(other, ColumnBase):
Expand Down
30 changes: 14 additions & 16 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6772,9 +6772,7 @@ def _apply_cupy_method_axis_1(self, method, *args, **kwargs):
)
result = column.as_column(result, dtype=result_dtype)
if mask is not None:
result = result.set_mask(
cudf._lib.transform.bools_to_mask(mask._column)
)
result = result.set_mask(mask._column.as_mask())
return Series._from_column(result, index=self.index)
else:
result_df = DataFrame(result, index=self.index)
Expand Down Expand Up @@ -7883,6 +7881,16 @@ def interleave_columns(self):
)
return self._constructor_sliced._from_column(result_col)

@acquire_spill_lock()
def _compute_columns(self, expr: str) -> ColumnBase:
    """Evaluate a single expression string against this frame's columns.

    Parameters
    ----------
    expr : str
        Expression text; it is converted with
        ``plc.expressions.to_expression`` using this frame's column
        names.

    Returns
    -------
    ColumnBase
        The column produced by ``plc.transform.compute_column``.
    """
    # Build the read-only pylibcudf table first, then parse the
    # expression, mirroring the argument evaluation order of the call.
    input_table = plc.Table(
        [col.to_pylibcudf(mode="read") for col in self._columns]
    )
    parsed_expr = plc.expressions.to_expression(expr, self._column_names)
    computed = plc.transform.compute_column(input_table, parsed_expr)
    return libcudf.column.Column.from_pylibcudf(computed)

@_performance_tracking
def eval(self, expr: str, inplace: bool = False, **kwargs):
"""Evaluate a string describing operations on DataFrame columns.
Expand Down Expand Up @@ -8010,11 +8018,7 @@ def eval(self, expr: str, inplace: bool = False, **kwargs):
raise ValueError(
"Cannot operate inplace if there is no assignment"
)
return Series._from_column(
libcudf.transform.compute_column(
[*self._columns], self._column_names, statements[0]
)
)
return Series._from_column(self._compute_columns(statements[0]))

targets = []
exprs = []
Expand All @@ -8030,15 +8034,9 @@ def eval(self, expr: str, inplace: bool = False, **kwargs):
targets.append(t.strip())
exprs.append(e.strip())

cols = (
libcudf.transform.compute_column(
[*self._columns], self._column_names, e
)
for e in exprs
)
ret = self if inplace else self.copy(deep=False)
for name, col in zip(targets, cols):
ret._data[name] = col
for name, expr in zip(targets, exprs):
ret._data[name] = self._compute_columns(expr)
if not inplace:
return ret

Expand Down
3 changes: 1 addition & 2 deletions python/cudf/cudf/core/df_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,7 @@ def _set_missing_values(
valid_mask = _ensure_gpu_buffer(
valid_mask[0], valid_mask[1], allow_copy
)
boolmask = as_column(valid_mask._buf, dtype="bool")
bitmask = cudf._lib.transform.bools_to_mask(boolmask)
bitmask = as_column(valid_mask._buf, dtype="bool").as_mask()
return cudf_col.set_mask(bitmask)
elif null == _MaskKind.BITMASK:
valid_mask = _ensure_gpu_buffer(
Expand Down
9 changes: 8 additions & 1 deletion python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,7 +1457,14 @@ def _split(self, splits):

@_performance_tracking
def _encode(self):
columns, indices = libcudf.transform.table_encode(list(self._columns))
plc_table, plc_column = plc.transform.encode(
plc.Table([col.to_pylibcudf(mode="read") for col in self._columns])
)
columns = [
libcudf.column.Column.from_pylibcudf(col)
for col in plc_table.columns()
]
indices = libcudf.column.Column.from_pylibcudf(plc_column)
keys = self._from_columns_like_self(columns)
return keys, indices

Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3507,7 +3507,7 @@ def _apply(self, func, kernel_getter, *args, **kwargs):

col = _post_process_output_col(ans_col, retty)

col.set_base_mask(libcudf.transform.bools_to_mask(ans_mask))
col.set_base_mask(ans_mask.as_mask())
result = cudf.Series._from_column(col, index=self.index)

return result
Expand Down
Loading

0 comments on commit 9df95d1

Please sign in to comment.