Skip to content

Commit

Permalink
NumPy hid sctypes further away but still use it for now
Browse files Browse the repository at this point in the history
The alternative is listing them all, but even that is a bit tedious
for integers, because it needs something like:

``(np.byte, np.short, np.intc, getattr(np, "long", np.int_), np.longlong)``
  • Loading branch information
seberg committed Apr 24, 2024
1 parent 323db6c commit 8aaa435
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
7 changes: 5 additions & 2 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from cudf.core.single_column_frame import SingleColumnFrame
from cudf.utils.docutils import copy_docstring
from cudf.utils.dtypes import (
_NUMPY_SCTYPES,
_maybe_convert_to_default_type,
find_common_type,
is_mixed_with_object_dtype,
Expand Down Expand Up @@ -341,8 +342,10 @@ def _data(self):
@_cudf_nvtx_annotate
def __contains__(self, item):
if isinstance(item, bool) or not isinstance(
item, tuple(np.sctypes["int"] + np.sctypes["float"] + [int, float])
):
item,
tuple(
_NUMPY_SCTYPES["int"] + _NUMPY_SCTYPES["float"] + [int, float])
):
return False
try:
int_item = int(item)
Expand Down
8 changes: 6 additions & 2 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@
BOOL_TYPES = {"bool"}
ALL_TYPES = NUMERIC_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES | OTHER_TYPES

# The NumPy scalar types are a bit of a mess as they align with the C types
# so for now we use the `sctypes` dict (although it was made private in 2.0)
_NUMPY_SCTYPES = np.sctypes if hasattr(np, "sctypes") else np._core.sctypes


def np_to_pa_dtype(dtype):
"""Util to convert numpy dtype to PyArrow dtype."""
Expand Down Expand Up @@ -335,7 +339,7 @@ def min_signed_type(x, min_size=8):
Return the smallest *signed* integer dtype
that can represent the integer ``x``
"""
for int_dtype in np.sctypes["int"]:
for int_dtype in _NUMPY_SCTYPES["int"]:
if (cudf.dtype(int_dtype).itemsize * 8) >= min_size:
if np.iinfo(int_dtype).min <= x <= np.iinfo(int_dtype).max:
return int_dtype
Expand All @@ -348,7 +352,7 @@ def min_unsigned_type(x, min_size=8):
Return the smallest *unsigned* integer dtype
that can represent the integer ``x``
"""
for int_dtype in np.sctypes["uint"]:
for int_dtype in _NUMPY_SCTYPES["uint"]:
if (cudf.dtype(int_dtype).itemsize * 8) >= min_size:
if 0 <= x <= np.iinfo(int_dtype).max:
return int_dtype
Expand Down

0 comments on commit 8aaa435

Please sign in to comment.