diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx index 559302ccc..9d1d24eae 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array from cuvs.distance import DISTANCE_TYPES +from cuvs.neighbors.common import _check_input_array from cuvs.common.c_api cimport cuvsResources_t diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx index 95209dbeb..752aef741 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx @@ -32,7 +32,8 @@ from cuvs.common cimport cydlpack from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array + +from cuvs.neighbors.common import _check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/common.py b/python/cuvs/cuvs/neighbors/common.py new file mode 100644 index 000000000..c14b9f8c9 --- /dev/null +++ b/python/cuvs/cuvs/neighbors/common.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def _check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None): + if cai.dtype not in exp_dt: + raise TypeError("dtype %s not supported" % cai.dtype) + + if not cai.c_contiguous: + raise ValueError("Row major input is expected") + + if exp_cols is not None and cai.shape[1] != exp_cols: + raise ValueError( + "Incorrect number of columns, expected {} got {}".format( + exp_cols, cai.shape[1] + ) + ) + + if exp_rows is not None and cai.shape[0] != exp_rows: + raise ValueError( + "Incorrect number of rows, expected {} , got {}".format( + exp_rows, cai.shape[0] + ) + ) diff --git a/python/cuvs/cuvs/neighbors/filters/filters.pyx b/python/cuvs/cuvs/neighbors/filters/filters.pyx index 3a81cb786..9bc2a905c 100644 --- a/python/cuvs/cuvs/neighbors/filters/filters.pyx +++ b/python/cuvs/cuvs/neighbors/filters/filters.pyx @@ -20,11 +20,11 @@ import numpy as np from libc.stdint cimport uintptr_t from cuvs.common cimport cydlpack +from cuvs.neighbors.common import _check_input_array from .filters cimport BITMAP, NO_FILTER, cuvsFilter from pylibraft.common.cai_wrapper import wrap_array -from pylibraft.neighbors.common import _check_input_array cdef class Prefilter: diff --git a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx index 018fcfef9..bcfaf167e 100644 --- a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx +++ b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx @@ -21,6 +21,7 @@ from libcpp.string cimport string from cuvs.common.exceptions import check_cuvs from cuvs.common.resources import auto_sync_resources +from cuvs.neighbors.common import _check_input_array from cuvs.common cimport cydlpack @@ -36,7 +37,6 @@ import uuid from pylibraft.common import auto_convert_output from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array cdef class SearchParams: diff --git a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx index 25b9b2aee..7a169e1a0 100644 --- a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array from cuvs.distance import DISTANCE_TYPES +from cuvs.neighbors.common import _check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx b/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx index 3add1df75..531302ee6 100644 --- a/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array from cuvs.distance import DISTANCE_TYPES +from cuvs.neighbors.common import _check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/refine.pyx b/python/cuvs/cuvs/neighbors/refine.pyx index 0eccc4108..b7aa35dca 100644 --- a/python/cuvs/cuvs/neighbors/refine.pyx +++ b/python/cuvs/cuvs/neighbors/refine.pyx @@ -31,13 +31,13 @@ from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array from cuvs.distance import DISTANCE_TYPES from cuvs.common.c_api cimport cuvsResources_t from cuvs.common.exceptions import check_cuvs +from cuvs.neighbors.common import _check_input_array @auto_sync_resources