From 849049e7a401db70777b917ae7eec7e709c6e704 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Wed, 24 Apr 2024 08:59:51 +0000 Subject: [PATCH] NumPy hid sctypes further away but still use it for now The alternative is listing them all, but even that is a bit tedious for integers, because it needs something like: ``(np.byte, np.short, np.intc, getattr(np, "long", np.int_), np.longlong)`` --- python/cudf/cudf/core/index.py | 6 +++++- python/cudf/cudf/utils/dtypes.py | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 6f08b1d83b3..e0797fad796 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -60,6 +60,7 @@ from cudf.core.single_column_frame import SingleColumnFrame from cudf.utils.docutils import copy_docstring from cudf.utils.dtypes import ( + _NUMPY_SCTYPES, _maybe_convert_to_default_type, find_common_type, is_mixed_with_object_dtype, @@ -341,7 +342,10 @@ def _data(self): @_cudf_nvtx_annotate def __contains__(self, item): if isinstance(item, bool) or not isinstance( - item, tuple(np.sctypes["int"] + np.sctypes["float"] + [int, float]) + item, + tuple( + _NUMPY_SCTYPES["int"] + _NUMPY_SCTYPES["float"] + [int, float] + ), ): return False try: diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index a33b5ca139c..2aa3129ab30 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -91,6 +91,10 @@ BOOL_TYPES = {"bool"} ALL_TYPES = NUMERIC_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES | OTHER_TYPES +# The NumPy scalar types are a bit of a mess as they align with the C types +# so for now we use the `sctypes` dict (although it was made private in 2.0) +_NUMPY_SCTYPES = np.sctypes if hasattr(np, "sctypes") else np._core.sctypes + def np_to_pa_dtype(dtype): """Util to convert numpy dtype to PyArrow dtype.""" @@ -335,7 +339,7 @@ def min_signed_type(x, min_size=8): Return the smallest *signed* integer dtype that can represent the integer ``x`` """ - for int_dtype in np.sctypes["int"]: + for int_dtype in _NUMPY_SCTYPES["int"]: if (cudf.dtype(int_dtype).itemsize * 8) >= min_size: if np.iinfo(int_dtype).min <= x <= np.iinfo(int_dtype).max: return int_dtype @@ -348,7 +352,7 @@ def min_unsigned_type(x, min_size=8): Return the smallest *unsigned* integer dtype that can represent the integer ``x`` """ - for int_dtype in np.sctypes["uint"]: + for int_dtype in _NUMPY_SCTYPES["uint"]: if (cudf.dtype(int_dtype).itemsize * 8) >= min_size: if 0 <= x <= np.iinfo(int_dtype).max: return int_dtype