diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 81b82aa7b..32093776c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -369,6 +369,7 @@ if(BUILD_SHARED_LIBS) src/distance/detail/fused_distance_nn.cu src/distance/distance.cu src/distance/pairwise_distance.cu + src/distance/sparse_distance.cu src/neighbors/brute_force.cu src/neighbors/cagra_build_float.cu src/neighbors/cagra_build_half.cu @@ -449,6 +450,7 @@ if(BUILD_SHARED_LIBS) src/neighbors/refine/detail/refine_host_int8_t_float.cpp src/neighbors/refine/detail/refine_host_uint8_t_float.cpp src/neighbors/sample_filter.cu + src/neighbors/sparse_brute_force.cu src/neighbors/vamana_build_float.cu src/neighbors/vamana_build_uint8.cu src/neighbors/vamana_build_int8.cu diff --git a/cpp/include/cuvs/distance/distance.hpp b/cpp/include/cuvs/distance/distance.hpp index def72641e..42c574e58 100644 --- a/cpp/include/cuvs/distance/distance.hpp +++ b/cpp/include/cuvs/distance/distance.hpp @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -331,6 +332,86 @@ void pairwise_distance( cuvs::distance::DistanceType metric, float metric_arg = 2.0f); +/** + * @brief Compute sparse pairwise distances between x and y, using the provided + * input configuration and distance function. + * + * @code{.cpp} + * #include + * #include + * #include + * + * int x_n_rows = 100000; + * int y_n_rows = 50000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * auto x = raft::make_device_csr_matrix(handle, x_n_rows, n_cols); + * auto y = raft::make_device_csr_matrix(handle, y_n_rows, n_cols); + * + * ... + * // populate data + * ... + * + * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); + * auto metric = cuvs::distance::DistanceType::L2Expanded; + * raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric); + * @endcode + * + * @param[in] handle raft::resources + * @param[in] x raft::device_csr_matrix_view + * @param[in] y raft::device_csr_matrix_view + * @param[out] dist raft::device_matrix_view dense matrix + * @param[in] metric distance metric to use + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance(raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f); + +/** + * @brief Compute sparse pairwise distances between x and y, using the provided + * input configuration and distance function. + * + * @code{.cpp} + * #include + * #include + * #include + * + * int x_n_rows = 100000; + * int y_n_rows = 50000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * auto x = raft::make_device_csr_matrix(handle, x_n_rows, n_cols); + * auto y = raft::make_device_csr_matrix(handle, y_n_rows, n_cols); + * + * ... + * // populate data + * ... + * + * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); + * auto metric = cuvs::distance::DistanceType::L2Expanded; + * raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric); + * @endcode + * + * @param[in] handle raft::resources + * @param[in] x raft::device_csr_matrix_view + * @param[in] y raft::device_csr_matrix_view + * @param[out] dist raft::device_matrix_view dense matrix + * @param[in] metric distance metric to use + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance(raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f); + /** @} */ // end group pairwise_distance_runtime }; // namespace cuvs::distance diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index 428fa592a..ba67797ee 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -18,6 +18,7 @@ #include "common.hpp" #include +#include #include #include #include @@ -375,4 +376,107 @@ void search(raft::resources const& handle, * @} */ +/** + * @defgroup sparse_bruteforce_cpp_index Sparse Brute Force index + * @{ + */ +/** + * @brief Sparse Brute Force index. + * + * @tparam T Data element type + * @tparam IdxT Index element type + */ +template +struct sparse_index { + public: + sparse_index(const sparse_index&) = delete; + sparse_index(sparse_index&&) = default; + sparse_index& operator=(const sparse_index&) = delete; + sparse_index& operator=(sparse_index&&) = default; + ~sparse_index() = default; + + /** Construct a sparse brute force sparse_index from dataset */ + sparse_index(raft::resources const& res, + raft::device_csr_matrix_view dataset, + cuvs::distance::DistanceType metric, + T metric_arg); + + /** Distance metric used for retrieval */ + cuvs::distance::DistanceType metric() const noexcept { return metric_; } + + /** Metric argument */ + T metric_arg() const noexcept { return metric_arg_; } + + raft::device_csr_matrix_view dataset() const noexcept + { + return dataset_; + } + + private: + raft::device_csr_matrix_view dataset_; + cuvs::distance::DistanceType metric_; + T metric_arg_; +}; +/** + * @} + */ + +/** + * @defgroup sparse_bruteforce_cpp_index_build Sparse Brute Force index build + * @{ + */ + +/* + * @brief Build the Sparse index from the dataset + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // create and fill the index from a CSR dataset + * auto index = brute_force::build(handle, dataset, metric); + * @endcode + * + * @param[in] handle + * @param[in] dataset A sparse CSR matrix in device memory to search against + * @param[in] metric cuvs::distance::DistanceType + * @param[in] metric_arg metric argument + * + * @return the constructed Sparse brute-force index + */ +auto build(raft::resources const& handle, + raft::device_csr_matrix_view dataset, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, + float metric_arg = 0) -> cuvs::neighbors::brute_force::sparse_index; +/** + * @} + */ + +/** + * @defgroup sparse_bruteforce_cpp_index_search Sparse Brute Force index search + * @{ + */ +struct sparse_search_params { + int batch_size_index = 2 << 14; + int batch_size_query = 2 << 14; +}; + +/* + * @brief Search the sparse bruteforce index for nearest neighbors + * + * @param[in] handle + * @param[in] index Sparse brute-force constructed index + * @param[in] queries a sparse CSR matrix on the device to query + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + */ +void search(raft::resources const& handle, + const sparse_search_params& params, + const sparse_index& index, + raft::device_csr_matrix_view dataset, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); +/** + * @} + */ } // namespace cuvs::neighbors::brute_force diff --git a/cpp/src/distance/detail/sparse/bin_distance.cuh b/cpp/src/distance/detail/sparse/bin_distance.cuh new file mode 100644 index 000000000..1a63a8eb9 --- /dev/null +++ b/cpp/src/distance/detail/sparse/bin_distance.cuh @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "common.hpp" +#include "ip_distance.cuh" + +#include +#include +#include +#include + +#include + +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { +// @TODO: Move this into sparse prims (coo_norm) +template +RAFT_KERNEL compute_binary_row_norm_kernel(value_t* out, + const value_idx* __restrict__ coo_rows, + const value_t* __restrict__ data, + value_idx nnz) +{ + value_idx i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < nnz) { + // We do conditional here only because it's + // possible there could be some stray zeros in + // the sparse structure and removing them would be + // more expensive. + atomicAdd(&out[coo_rows[i]], data[i] == 1.0); + } +} + +template +RAFT_KERNEL compute_binary_warp_kernel(value_t* __restrict__ C, + const value_t* __restrict__ Q_norms, + const value_t* __restrict__ R_norms, + value_idx n_rows, + value_idx n_cols, + expansion_f expansion_func) +{ + std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; + value_idx i = tid / n_cols; + value_idx j = tid % n_cols; + + if (i >= n_rows || j >= n_cols) return; + + value_t q_norm = Q_norms[i]; + value_t r_norm = R_norms[j]; + value_t dot = C[(size_t)i * n_cols + j]; + C[(size_t)i * n_cols + j] = expansion_func(dot, q_norm, r_norm); +} + +template +void compute_binary(value_t* C, + const value_t* Q_norms, + const value_t* R_norms, + value_idx n_rows, + value_idx n_cols, + expansion_f expansion_func, + cudaStream_t stream) +{ + int blocks = raft::ceildiv((size_t)n_rows * n_cols, tpb); + compute_binary_warp_kernel<<>>( + C, Q_norms, R_norms, n_rows, n_cols, expansion_func); +} + +template +void compute_bin_distance(value_t* out, + const value_idx* Q_coo_rows, + const value_t* Q_data, + value_idx Q_nnz, + const value_idx* R_coo_rows, + const value_t* R_data, + value_idx R_nnz, + value_idx m, + value_idx n, + cudaStream_t stream, + expansion_f expansion_func) +{ + rmm::device_uvector Q_norms(m, stream); + rmm::device_uvector R_norms(n, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(Q_norms.data(), 0, Q_norms.size() * sizeof(value_t))); + RAFT_CUDA_TRY(cudaMemsetAsync(R_norms.data(), 0, R_norms.size() * sizeof(value_t))); + + compute_binary_row_norm_kernel<<>>( + Q_norms.data(), Q_coo_rows, Q_data, Q_nnz); + compute_binary_row_norm_kernel<<>>( + R_norms.data(), R_coo_rows, R_data, R_nnz); + + compute_binary(out, Q_norms.data(), R_norms.data(), m, n, expansion_func, stream); +} + +/** + * Jaccard distance using the expanded form: + * 1 - (sum(x_k * y_k) / ((sum(x_k) + sum(y_k)) - sum(x_k * y_k)) + */ +template +class jaccard_expanded_distances_t : public distances_t { + public: + explicit jaccard_expanded_distances_t(const distances_config_t& config) + : config_(&config), + workspace(0, raft::resource::get_cuda_stream(config.handle)), + ip_dists(config) + { + } + + void compute(value_t* out_dists) + { + ip_dists.compute(out_dists); + + value_idx* b_indices = ip_dists.b_rows_coo(); + value_t* b_data = ip_dists.b_data_coo(); + + rmm::device_uvector search_coo_rows( + config_->a_nnz, raft::resource::get_cuda_stream(config_->handle)); + raft::sparse::convert::csr_to_coo(config_->a_indptr, + config_->a_nrows, + search_coo_rows.data(), + config_->a_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + compute_bin_distance(out_dists, + search_coo_rows.data(), + config_->a_data, + config_->a_nnz, + b_indices, + b_data, + config_->b_nnz, + config_->a_nrows, + config_->b_nrows, + raft::resource::get_cuda_stream(config_->handle), + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { + value_t q_r_union = q_norm + r_norm; + value_t denom = q_r_union - dot; + + value_t jacc = ((denom != 0) * dot) / ((denom == 0) + denom); + + // flip the similarity when both rows are 0 + bool both_empty = q_r_union == 0; + return 1 - ((!both_empty * jacc) + both_empty); + }); + } + + ~jaccard_expanded_distances_t() = default; + + private: + const distances_config_t* config_; + rmm::device_uvector workspace; + ip_distances_t ip_dists; +}; + +/** + * Dice distance using the expanded form: + * 1 - ((2 * sum(x_k * y_k)) / (sum(x_k) + sum(y_k))) + */ +template +class dice_expanded_distances_t : public distances_t { + public: + explicit dice_expanded_distances_t(const distances_config_t& config) + : config_(&config), + workspace(0, raft::resource::get_cuda_stream(config.handle)), + ip_dists(config) + { + } + + void compute(value_t* out_dists) + { + ip_dists.compute(out_dists); + + value_idx* b_indices = ip_dists.b_rows_coo(); + value_t* b_data = ip_dists.b_data_coo(); + + rmm::device_uvector search_coo_rows( + config_->a_nnz, raft::resource::get_cuda_stream(config_->handle)); + raft::sparse::convert::csr_to_coo(config_->a_indptr, + config_->a_nrows, + search_coo_rows.data(), + config_->a_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + compute_bin_distance(out_dists, + search_coo_rows.data(), + config_->a_data, + config_->a_nnz, + b_indices, + b_data, + config_->b_nnz, + config_->a_nrows, + config_->b_nrows, + raft::resource::get_cuda_stream(config_->handle), + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { + value_t q_r_union = q_norm + r_norm; + value_t dice = (2 * dot) / q_r_union; + bool both_empty = q_r_union == 0; + return 1 - ((!both_empty * dice) + both_empty); + }); + } + + ~dice_expanded_distances_t() = default; + + private: + const distances_config_t* config_; + rmm::device_uvector workspace; + ip_distances_t ip_dists; +}; + +} // END namespace sparse +} // END namespace detail +} // END namespace distance +} // END namespace cuvs diff --git a/cpp/src/distance/detail/sparse/common.hpp b/cpp/src/distance/detail/sparse/common.hpp new file mode 100644 index 000000000..803dabe56 --- /dev/null +++ b/cpp/src/distance/detail/sparse/common.hpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +template +struct distances_config_t { + distances_config_t(raft::resources const& handle_) : handle(handle_) {} + + // left side + value_idx a_nrows; + value_idx a_ncols; + value_idx a_nnz; + value_idx* a_indptr; + value_idx* a_indices; + value_t* a_data; + + // right side + value_idx b_nrows; + value_idx b_ncols; + value_idx b_nnz; + value_idx* b_indptr; + value_idx* b_indices; + value_t* b_data; + + raft::resources const& handle; +}; + +template +class distances_t { + public: + virtual void compute(value_t* out) {} + virtual ~distances_t() = default; +}; + +} // namespace sparse +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/sparse/coo_spmv.cuh b/cpp/src/distance/detail/sparse/coo_spmv.cuh new file mode 100644 index 000000000..181b531f7 --- /dev/null +++ b/cpp/src/distance/detail/sparse/coo_spmv.cuh @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "common.hpp" +#include "coo_spmv_strategies/dense_smem_strategy.cuh" +#include "coo_spmv_strategies/hash_strategy.cuh" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +template +inline void balanced_coo_pairwise_generalized_spmv( + value_t* out_dists, + const distances_config_t& config_, + value_idx* coo_rows_b, + product_f product_func, + accum_f accum_func, + write_f write_func, + strategy_t strategy, + int chunk_size = 500000) +{ + uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows; + RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, raft::resource::get_cuda_stream(config_.handle))); + + strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size); +}; + +/** + * Performs generalized sparse-matrix-sparse-matrix multiplication via a + * sparse-matrix-sparse-vector layout `out=A*B` where generalized product() + * and sum() operations can be used in place of the standard sum and product: + * + * out_ij = sum_k(product(A_ik, B_ik)) The sum goes through values of + * k=0..n_cols-1 where B_kj is nonzero. + * + * The product and sum operations shall form a semiring algebra with the + * following properties: + * 1. {+, 0} is a commutative sum reduction monoid with identity element 0 + * 2. {*, 1} is a product monoid with identity element 1 + * 3. Multiplication by 0 annihilates x. e.g. product(x, 0) = 0 + * + * Each vector of A is loaded into shared memory in dense form and the + * non-zeros of B load balanced across the threads of each block. + * @tparam value_idx index type + * @tparam value_t value type + * @tparam threads_per_block block size + * @tparam product_f semiring product() function + * @tparam accum_f semiring sum() function + * @tparam write_f atomic semiring sum() function + * @param[out] out_dists dense array of out distances of size m * n in row-major + * format. + * @param[in] config_ distance config object + * @param[in] coo_rows_b coo row array for B + * @param[in] product_func semiring product() function + * @param[in] accum_func semiring sum() function + * @param[in] write_func atomic semiring sum() function + * @param[in] chunk_size number of nonzeros of B to process for each row of A + * this value was found through profiling and represents a reasonable + * setting for both large and small densities + */ +template +inline void balanced_coo_pairwise_generalized_spmv( + value_t* out_dists, + const distances_config_t& config_, + value_idx* coo_rows_b, + product_f product_func, + accum_f accum_func, + write_f write_func, + int chunk_size = 500000) +{ + uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows; + RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, raft::resource::get_cuda_stream(config_.handle))); + + int max_cols = max_cols_per_block(); + + if (max_cols > config_.a_ncols) { + dense_smem_strategy strategy(config_); + strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size); + } else { + hash_strategy strategy(config_); + strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size); + } +}; + +template +inline void balanced_coo_pairwise_generalized_spmv_rev( + value_t* out_dists, + const distances_config_t& config_, + value_idx* coo_rows_a, + product_f product_func, + accum_f accum_func, + write_f write_func, + strategy_t strategy, + int chunk_size = 500000) +{ + strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, write_func, chunk_size); +}; + +/** + * Used for computing distances where the reduction (e.g. product()) function + * requires an implicit union (product(x, 0) = x) to capture the difference A-B. + * This is necessary in some applications because the standard semiring algebra + * endowed with the default multiplication product monoid will only + * compute the intersection & B-A. + * + * This particular function is meant to accompany the function + * `balanced_coo_pairwise_generalized_spmv` and executes the product operation + * on only those columns that exist in B and not A. + * + * The product and sum operations shall enable the computation of a + * non-annihilating semiring algebra with the following properties: + * 1. {+, 0} is a commutative sum reduction monoid with identity element 0 + * 2. {*, 0} is a product monoid with identity element 0 + * 3. Multiplication by 0 does not annihilate x. e.g. product(x, 0) = x + * + * Manattan distance sum(abs(x_k-y_k)) is a great example of when this type of + * execution pattern is necessary. + * + * @tparam value_idx index type + * @tparam value_t value type + * @tparam threads_per_block block size + * @tparam product_f semiring product() function + * @tparam accum_f semiring sum() function + * @tparam write_f atomic semiring sum() function + * @param[out] out_dists dense array of out distances of size m * n + * @param[in] config_ distance config object + * @param[in] coo_rows_a coo row array for A + * @param[in] product_func semiring product() function + * @param[in] accum_func semiring sum() function + * @param[in] write_func atomic semiring sum() function + * @param[in] chunk_size number of nonzeros of B to process for each row of A + * this value was found through profiling and represents a reasonable + * setting for both large and small densities + */ +template +inline void balanced_coo_pairwise_generalized_spmv_rev( + value_t* out_dists, + const distances_config_t& config_, + value_idx* coo_rows_a, + product_f product_func, + accum_f accum_func, + write_f write_func, + int chunk_size = 500000) +{ + // try dense first + int max_cols = max_cols_per_block(); + + if (max_cols > config_.b_ncols) { + dense_smem_strategy strategy(config_); + strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, write_func, chunk_size); + } else { + hash_strategy strategy(config_); + strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, write_func, chunk_size); + } +}; + +} // namespace sparse +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh b/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh new file mode 100644 index 000000000..1f4b19af4 --- /dev/null +++ b/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { +__device__ __inline__ unsigned int get_lowest_peer(unsigned int peer_group) +{ + return __ffs(peer_group) - 1; +} + +/** + * Load-balanced sparse-matrix-sparse-matrix multiplication (SPMM) kernel with + * sparse-matrix-sparse-vector multiplication layout (SPMV). + * This is intended to be scheduled n_chunks_b times for each row of a. + * The steps are as follows: + * + * 1. Load row from A into dense vector in shared memory. + * This can be further chunked in the future if necessary to support larger + * column sizes. + * 2. Threads of block all step through chunks of B in parallel. + * When a new row is encountered in row_indices_b, a segmented + * reduction is performed across the warps and then across the + * block and the final value written out to host memory. + * + * Reference: https://www.icl.utk.edu/files/publications/2020/icl-utk-1421-2020.pdf + * + * @tparam value_idx index type + * @tparam value_t value type + * @tparam tpb threads per block configured on launch + * @tparam rev if this is true, the reduce/accumulate functions are only + * executed when A[col] == 0.0. when executed before/after !rev + * and A & B are reversed, this allows the full symmetric difference + * and intersection to be computed. + * @tparam kv_t data type stored in shared mem cache + * @tparam product_f reduce function type (semiring product() function). + * accepts two arguments of value_t and returns a value_t + * @tparam accum_f accumulation function type (semiring sum() function). + * accepts two arguments of value_t and returns a value_t + * @tparam write_f function to write value out. this should be mathematically + * equivalent to the accumulate function but implemented as + * an atomic operation on global memory. Accepts two arguments + * of value_t* and value_t and updates the value given by the + * pointer. + * @param[in] indptrA column pointer array for A + * @param[in] indicesA column indices array for A + * @param[in] dataA data array for A + * @param[in] rowsB coo row array for B + * @param[in] indicesB column indices array for B + * @param[in] dataB data array for B + * @param[in] m number of rows in A + * @param[in] n number of rows in B + * @param[in] dim number of features + * @param[in] nnz_b number of nonzeros in B + * @param[out] out array of size m*n + * @param[in] n_blocks_per_row number of blocks of B per row of A + * @param[in] chunk_size number of nnz for B to use for each row of A + * @param[in] buffer_size amount of smem to use for each row of A + * @param[in] product_func semiring product() function + * @param[in] accum_func semiring sum() function + * @param[in] write_func atomic semiring sum() function + */ +template +RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy, + indptr_it indptrA, + value_idx* indicesA, + value_t* dataA, + value_idx nnz_a, + value_idx* rowsB, + value_idx* indicesB, + value_t* dataB, + value_idx m, + value_idx n, + int dim, + value_idx nnz_b, + value_t* out, + int n_blocks_per_row, + int chunk_size, + value_idx b_ncols, + product_f product_func, + accum_f accum_func, + write_f write_func) +{ + typedef cub::WarpReduce warp_reduce; + + value_idx cur_row_a = indptrA.get_row_idx(n_blocks_per_row); + value_idx cur_chunk_offset = blockIdx.x % n_blocks_per_row; + + // chunk starting offset + value_idx ind_offset = cur_chunk_offset * chunk_size * tpb; + // how many total cols will be processed by this block (should be <= chunk_size * n_threads) + value_idx active_chunk_size = min(chunk_size * tpb, nnz_b - ind_offset); + + int tid = threadIdx.x; + int warp_id = tid / raft::warp_size(); + + // compute id relative to current warp + unsigned int lane_id = tid & (raft::warp_size() - 1); + value_idx ind = ind_offset + threadIdx.x; + + extern __shared__ char smem[]; + + typename strategy_t::smem_type A = (typename strategy_t::smem_type)(smem); + typename warp_reduce::TempStorage* temp_storage = (typename warp_reduce::TempStorage*)(A + dim); + + auto inserter = strategy.init_insert(A, dim); + + __syncthreads(); + + value_idx start_offset_a, stop_offset_a; + bool first_a_chunk, last_a_chunk; + indptrA.get_row_offsets( + cur_row_a, start_offset_a, stop_offset_a, n_blocks_per_row, first_a_chunk, last_a_chunk); + + // Convert current row vector in A to dense + for (int i = tid; i <= (stop_offset_a - start_offset_a); i += blockDim.x) { + strategy.insert(inserter, indicesA[start_offset_a + i], dataA[start_offset_a + i]); + } + + __syncthreads(); + + auto finder = strategy.init_find(A, dim); + + if (cur_row_a > m || cur_chunk_offset > n_blocks_per_row) return; + if (ind >= nnz_b) return; + + value_idx start_index_a = 0, stop_index_a = b_ncols - 1; + indptrA.get_indices_boundary(indicesA, + cur_row_a, + start_offset_a, + stop_offset_a, + start_index_a, + stop_index_a, + first_a_chunk, + last_a_chunk); + + value_idx cur_row_b = -1; + value_t c = 0.0; + + auto warp_red = warp_reduce(*(temp_storage + warp_id)); + + if (tid < active_chunk_size) { + cur_row_b = rowsB[ind]; + + auto index_b = indicesB[ind]; + auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b); + + if (in_bounds) { + value_t a_col = strategy.find(finder, index_b); + if (!rev || a_col == 0.0) { c = product_func(a_col, dataB[ind]); } + } + } + + // loop through chunks in parallel, reducing when a new row is + // encountered by each thread + for (int i = tid; i < active_chunk_size; i += blockDim.x) { + value_idx ind_next = ind + blockDim.x; + value_idx next_row_b = -1; + + if (i + blockDim.x < active_chunk_size) next_row_b = rowsB[ind_next]; + + bool diff_rows = next_row_b != cur_row_b; + + if (__any_sync(0xffffffff, diff_rows)) { + // grab the threads currently participating in loops. + // because any other threads should have returned already. + unsigned int peer_group = __match_any_sync(0xffffffff, cur_row_b); + bool is_leader = get_lowest_peer(peer_group) == lane_id; + value_t v = warp_red.HeadSegmentedReduce(c, is_leader, accum_func); + + // thread with lowest lane id among peers writes out + if (is_leader && v != 0.0) { + // this conditional should be uniform, since rev is constant + size_t idx = !rev ? (size_t)cur_row_a * n + cur_row_b : (size_t)cur_row_b * m + cur_row_a; + write_func(out + idx, v); + } + + c = 0.0; + } + + if (next_row_b != -1) { + ind = ind_next; + + auto index_b = indicesB[ind]; + auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b); + if (in_bounds) { + value_t a_col = strategy.find(finder, index_b); + + if (!rev || a_col == 0.0) { c = accum_func(c, product_func(a_col, dataB[ind])); } + } + + cur_row_b = next_row_b; + } + } +} + +} // namespace sparse +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/sparse/coo_spmv_strategies/base_strategy.cuh b/cpp/src/distance/detail/sparse/coo_spmv_strategies/base_strategy.cuh new file mode 100644 index 000000000..457b25eea --- /dev/null +++ b/cpp/src/distance/detail/sparse/coo_spmv_strategies/base_strategy.cuh @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../common.hpp" +#include "../coo_spmv_kernel.cuh" +#include "../utils.cuh" +#include "coo_mask_row_iterators.cuh" + +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +template +class coo_spmv_strategy { + public: + coo_spmv_strategy(const distances_config_t& config_) : config(config_) + { + smem = raft::getSharedMemPerBlock(); + } + + template + void _dispatch_base(strategy_t& strategy, + int smem_dim, + indptr_it& a_indptr, + value_t* out_dists, + value_idx* coo_rows_b, + product_f product_func, + accum_f accum_func, + write_f write_func, + int chunk_size, + int n_blocks, + int n_blocks_per_row) + { + RAFT_CUDA_TRY(cudaFuncSetCacheConfig(balanced_coo_generalized_spmv_kernel, + cudaFuncCachePreferShared)); + + balanced_coo_generalized_spmv_kernel + <<>>(strategy, + a_indptr, + config.a_indices, + config.a_data, + config.a_nnz, + coo_rows_b, + config.b_indices, + config.b_data, + config.a_nrows, + config.b_nrows, + smem_dim, + config.b_nnz, + out_dists, + n_blocks_per_row, + chunk_size, + config.b_ncols, + product_func, + accum_func, + write_func); + } + + template + void _dispatch_base_rev(strategy_t& strategy, + int smem_dim, + indptr_it& b_indptr, + value_t* out_dists, + value_idx* coo_rows_a, + product_f product_func, + accum_f accum_func, + write_f write_func, + int chunk_size, + int n_blocks, + int n_blocks_per_row) + { + RAFT_CUDA_TRY(cudaFuncSetCacheConfig(balanced_coo_generalized_spmv_kernel, + cudaFuncCachePreferShared)); + + balanced_coo_generalized_spmv_kernel + <<>>(strategy, + b_indptr, + config.b_indices, + config.b_data, + config.b_nnz, + coo_rows_a, + config.a_indices, + config.a_data, + config.b_nrows, + config.a_nrows, + smem_dim, + config.a_nnz, + out_dists, + n_blocks_per_row, + chunk_size, + config.a_ncols, + product_func, + accum_func, + write_func); + } + + protected: + int smem; + const distances_config_t& config; +}; + +} // namespace sparse +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/sparse/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/src/distance/detail/sparse/coo_spmv_strategies/coo_mask_row_iterators.cuh new file mode 100644 index 000000000..a9040e1d8 --- /dev/null +++ b/cpp/src/distance/detail/sparse/coo_spmv_strategies/coo_mask_row_iterators.cuh @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../common.hpp" +#include "../utils.cuh" + +#include // raft::ceildiv + +#include + +#include +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +template +class mask_row_it { + public: + mask_row_it(const value_idx* full_indptr_, + const value_idx& n_rows_, + value_idx* mask_row_idx_ = NULL) + : full_indptr(full_indptr_), mask_row_idx(mask_row_idx_), n_rows(n_rows_) + { + } + + __device__ inline value_idx get_row_idx(const int& n_blocks_nnz_b) + { + if (mask_row_idx != NULL) { + return mask_row_idx[blockIdx.x / n_blocks_nnz_b]; + } else { + return blockIdx.x / n_blocks_nnz_b; + } + } + + __device__ inline void get_row_offsets(const value_idx& row_idx, + value_idx& start_offset, + value_idx& stop_offset, + const value_idx& n_blocks_nnz_b, + bool& first_a_chunk, + bool& last_a_chunk) + { + start_offset = full_indptr[row_idx]; + stop_offset = full_indptr[row_idx + 1] - 1; + } + + __device__ constexpr inline void get_indices_boundary(const value_idx* indices, + value_idx& indices_len, + value_idx& start_offset, + value_idx& stop_offset, + value_idx& start_index, + value_idx& stop_index, + bool& first_a_chunk, + bool& last_a_chunk) + { + // do nothing; + } + + __device__ constexpr inline bool check_indices_bounds(value_idx& start_index_a, + value_idx& stop_index_a, + value_idx& index_b) + { + return true; + } + + const value_idx *full_indptr, &n_rows; + value_idx* mask_row_idx; +}; + +template +RAFT_KERNEL fill_chunk_indices_kernel(value_idx* n_chunks_per_row, + value_idx* chunk_indices, + value_idx n_rows) +{ + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < n_rows) { + auto start = n_chunks_per_row[tid]; + auto end = n_chunks_per_row[tid + 1]; + +#pragma unroll + for (int i = start; i < end; i++) { + chunk_indices[i] = tid; + } + } +} + +template +class chunked_mask_row_it : public mask_row_it { + public: + chunked_mask_row_it(const value_idx* full_indptr_, + const value_idx& n_rows_, + value_idx* mask_row_idx_, + int row_chunk_size_, + const value_idx* n_chunks_per_row_, + const value_idx* chunk_indices_, + const cudaStream_t stream_) + : mask_row_it(full_indptr_, n_rows_, mask_row_idx_), + row_chunk_size(row_chunk_size_), + n_chunks_per_row(n_chunks_per_row_), + chunk_indices(chunk_indices_), + stream(stream_) + { + } + + static void init(const value_idx* indptr, + const value_idx* mask_row_idx, + const value_idx& n_rows, + const int row_chunk_size, + rmm::device_uvector& n_chunks_per_row, + rmm::device_uvector& chunk_indices, + cudaStream_t stream) + { + auto policy = rmm::exec_policy(stream); + + constexpr value_idx first_element = 0; + n_chunks_per_row.set_element_async(0, first_element, stream); + n_chunks_per_row_functor chunk_functor(indptr, row_chunk_size); + thrust::transform( + policy, mask_row_idx, mask_row_idx + n_rows, n_chunks_per_row.begin() + 1, chunk_functor); + + thrust::inclusive_scan( + policy, n_chunks_per_row.begin() + 1, n_chunks_per_row.end(), n_chunks_per_row.begin() + 1); + + raft::update_host(&total_row_blocks, n_chunks_per_row.data() + n_rows, 1, stream); + + fill_chunk_indices(n_rows, n_chunks_per_row, chunk_indices, stream); + } + + __device__ inline value_idx get_row_idx(const int& n_blocks_nnz_b) + { + return this->mask_row_idx[chunk_indices[blockIdx.x / n_blocks_nnz_b]]; + } + + __device__ inline void get_row_offsets(const value_idx& row_idx, + value_idx& start_offset, + value_idx& stop_offset, + const int& n_blocks_nnz_b, + bool& first_a_chunk, + bool& last_a_chunk) + { + auto chunk_index = blockIdx.x / n_blocks_nnz_b; + auto chunk_val = chunk_indices[chunk_index]; + auto prev_n_chunks = n_chunks_per_row[chunk_val]; + auto relative_chunk = chunk_index - prev_n_chunks; + first_a_chunk = relative_chunk == 0; + + start_offset = this->full_indptr[row_idx] + relative_chunk * row_chunk_size; + stop_offset = start_offset + row_chunk_size; + + auto final_stop_offset = this->full_indptr[row_idx + 1]; + + last_a_chunk = stop_offset >= final_stop_offset; + stop_offset = last_a_chunk ? final_stop_offset - 1 : stop_offset - 1; + } + + __device__ inline void get_indices_boundary(const value_idx* indices, + value_idx& row_idx, + value_idx& start_offset, + value_idx& stop_offset, + value_idx& start_index, + value_idx& stop_index, + bool& first_a_chunk, + bool& last_a_chunk) + { + start_index = first_a_chunk ? start_index : indices[start_offset - 1] + 1; + stop_index = last_a_chunk ? stop_index : indices[stop_offset]; + } + + __device__ inline bool check_indices_bounds(value_idx& start_index_a, + value_idx& stop_index_a, + value_idx& index_b) + { + return (index_b >= start_index_a && index_b <= stop_index_a); + } + + inline static value_idx total_row_blocks = 0; + const cudaStream_t stream; + const value_idx *n_chunks_per_row, *chunk_indices; + value_idx row_chunk_size; + + struct n_chunks_per_row_functor { + public: + n_chunks_per_row_functor(const value_idx* indptr_, value_idx row_chunk_size_) + : indptr(indptr_), row_chunk_size(row_chunk_size_) + { + } + + __host__ __device__ value_idx operator()(const value_idx& i) + { + auto degree = indptr[i + 1] - indptr[i]; + return raft::ceildiv(degree, (value_idx)row_chunk_size); + } + + const value_idx* indptr; + value_idx row_chunk_size; + }; + + private: + static void fill_chunk_indices(const value_idx& n_rows, + rmm::device_uvector& n_chunks_per_row, + rmm::device_uvector& chunk_indices, + cudaStream_t stream) + { + auto n_threads = std::min(n_rows, 256); + auto n_blocks = raft::ceildiv(n_rows, (value_idx)n_threads); + + chunk_indices.resize(total_row_blocks, stream); + + fill_chunk_indices_kernel + <<>>(n_chunks_per_row.data(), chunk_indices.data(), n_rows); + } +}; + +} // namespace sparse +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/sparse/coo_spmv_strategies/dense_smem_strategy.cuh b/cpp/src/distance/detail/sparse/coo_spmv_strategies/dense_smem_strategy.cuh new file mode 100644 index 000000000..baa913a6c --- /dev/null +++ b/cpp/src/distance/detail/sparse/coo_spmv_strategies/dense_smem_strategy.cuh @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base_strategy.cuh" + +#include // raft::ceildiv + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +template +class dense_smem_strategy : public coo_spmv_strategy { + public: + using smem_type = value_t*; + using insert_type = smem_type; + using find_type = smem_type; + + dense_smem_strategy(const distances_config_t& config_) + : coo_spmv_strategy(config_) + { + } + + inline static int smem_per_block(int n_cols) + { + return (n_cols * sizeof(value_t)) + ((1024 / raft::warp_size()) * sizeof(value_t)); + } + + template + void dispatch(value_t* out_dists, + value_idx* coo_rows_b, + product_f product_func, + accum_f accum_func, + write_f write_func, + int chunk_size) + { + auto n_blocks_per_row = raft::ceildiv(this->config.b_nnz, chunk_size * 1024); + auto n_blocks = this->config.a_nrows * n_blocks_per_row; + + mask_row_it a_indptr(this->config.a_indptr, this->config.a_nrows); + + this->_dispatch_base(*this, + this->config.b_ncols, + a_indptr, + out_dists, + coo_rows_b, + product_func, + accum_func, + write_func, + chunk_size, + n_blocks, + n_blocks_per_row); + } + + template + void dispatch_rev(value_t* out_dists, + value_idx* coo_rows_a, + product_f product_func, + accum_f accum_func, + write_f write_func, + int chunk_size) + { + auto n_blocks_per_row = raft::ceildiv(this->config.a_nnz, chunk_size * 1024); + auto n_blocks = this->config.b_nrows * n_blocks_per_row; + + mask_row_it b_indptr(this->config.b_indptr, this->config.b_nrows); + + this->_dispatch_base_rev(*this, + this->config.a_ncols, + b_indptr, + out_dists, + coo_rows_a, + product_func, + accum_func, + write_func, + chunk_size, + n_blocks, + n_blocks_per_row); + } + + __device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size) + { + for (int k = threadIdx.x; k < cache_size; k += blockDim.x) { + cache[k] = 0.0; + } + return cache; + } + + __device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value) + { + cache[key] = value; + } + + __device__ inline find_type init_find(smem_type cache, const value_idx& cache_size) + { + return cache; + } + + __device__ inline value_t find(find_type cache, const value_idx& key) { return cache[key]; } +}; + +} // namespace sparse +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/sparse/coo_spmv_strategies/hash_strategy.cuh b/cpp/src/distance/detail/sparse/coo_spmv_strategies/hash_strategy.cuh new file mode 100644 index 000000000..cf212076b --- /dev/null +++ b/cpp/src/distance/detail/sparse/coo_spmv_strategies/hash_strategy.cuh @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base_strategy.cuh" + +#include +#include + +#include +#include +#include + +// this is needed by cuco as key, value must be bitwise comparable. +// compilers don't declare float/double as bitwise comparable +// but that is too strict +// for example, the following is true (or 0): +// float a = 5; +// float b = 5; +// memcmp(&a, &b, sizeof(float)); +CUCO_DECLARE_BITWISE_COMPARABLE(float); +CUCO_DECLARE_BITWISE_COMPARABLE(double); + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +template +class hash_strategy : public coo_spmv_strategy { + public: + using insert_type = typename cuco::legacy:: + static_map::device_mutable_view; + using smem_type = typename insert_type::slot_type*; + using find_type = + typename cuco::legacy::static_map::device_view; + + hash_strategy(const distances_config_t& config_, + float capacity_threshold_ = 0.5, + int map_size_ = get_map_size()) + : coo_spmv_strategy(config_), + capacity_threshold(capacity_threshold_), + map_size(map_size_) + { + } + + void chunking_needed(const value_idx* indptr, + const value_idx n_rows, + rmm::device_uvector& mask_indptr, + std::tuple& n_rows_divided, + cudaStream_t stream) + { + auto policy = raft::resource::get_thrust_policy(this->config.handle); + + auto less = thrust::copy_if(policy, + thrust::make_counting_iterator(value_idx(0)), + thrust::make_counting_iterator(n_rows), + mask_indptr.data(), + fits_in_hash_table(indptr, 0, capacity_threshold * map_size)); + std::get<0>(n_rows_divided) = less - mask_indptr.data(); + + auto more = thrust::copy_if( + policy, + thrust::make_counting_iterator(value_idx(0)), + thrust::make_counting_iterator(n_rows), + less, + fits_in_hash_table( + indptr, capacity_threshold * map_size, std::numeric_limits::max())); + std::get<1>(n_rows_divided) = more - less; + } + + template + void dispatch(value_t* out_dists, + value_idx* coo_rows_b, + product_f product_func, + accum_f accum_func, + write_f write_func, + int chunk_size) + { + auto n_blocks_per_row = raft::ceildiv(this->config.b_nnz, chunk_size * tpb); + rmm::device_uvector mask_indptr( + this->config.a_nrows, raft::resource::get_cuda_stream(this->config.handle)); + std::tuple n_rows_divided; + + chunking_needed(this->config.a_indptr, + this->config.a_nrows, + mask_indptr, + n_rows_divided, + raft::resource::get_cuda_stream(this->config.handle)); + + auto less_rows = std::get<0>(n_rows_divided); + if (less_rows > 0) { + mask_row_it less(this->config.a_indptr, less_rows, mask_indptr.data()); + + auto n_less_blocks = less_rows * n_blocks_per_row; + this->_dispatch_base(*this, + map_size, + less, + out_dists, + coo_rows_b, + product_func, + accum_func, + write_func, + chunk_size, + n_less_blocks, + n_blocks_per_row); + } + + auto more_rows = std::get<1>(n_rows_divided); + if (more_rows > 0) { + rmm::device_uvector n_chunks_per_row( + more_rows + 1, raft::resource::get_cuda_stream(this->config.handle)); + rmm::device_uvector chunk_indices( + 0, raft::resource::get_cuda_stream(this->config.handle)); + chunked_mask_row_it::init(this->config.a_indptr, + mask_indptr.data() + less_rows, + more_rows, + capacity_threshold * map_size, + n_chunks_per_row, + chunk_indices, + raft::resource::get_cuda_stream(this->config.handle)); + + chunked_mask_row_it more(this->config.a_indptr, + more_rows, + mask_indptr.data() + less_rows, + capacity_threshold * map_size, + n_chunks_per_row.data(), + chunk_indices.data(), + raft::resource::get_cuda_stream(this->config.handle)); + + auto n_more_blocks = more.total_row_blocks * n_blocks_per_row; + this->_dispatch_base(*this, + map_size, + more, + out_dists, + coo_rows_b, + product_func, + accum_func, + write_func, + chunk_size, + n_more_blocks, + n_blocks_per_row); + } + } + + template + void dispatch_rev(value_t* out_dists, + value_idx* coo_rows_a, + product_f product_func, + accum_f accum_func, + write_f write_func, + int chunk_size) + { + auto n_blocks_per_row = raft::ceildiv(this->config.a_nnz, chunk_size * tpb); + rmm::device_uvector mask_indptr( + this->config.b_nrows, raft::resource::get_cuda_stream(this->config.handle)); + std::tuple n_rows_divided; + + chunking_needed(this->config.b_indptr, + this->config.b_nrows, + mask_indptr, + n_rows_divided, + raft::resource::get_cuda_stream(this->config.handle)); + + auto less_rows = std::get<0>(n_rows_divided); + if (less_rows > 0) { + mask_row_it less(this->config.b_indptr, less_rows, mask_indptr.data()); + + auto n_less_blocks = less_rows * n_blocks_per_row; + this->_dispatch_base_rev(*this, + map_size, + less, + out_dists, + coo_rows_a, + product_func, + accum_func, + write_func, + chunk_size, + n_less_blocks, + n_blocks_per_row); + } + + auto more_rows = std::get<1>(n_rows_divided); + if (more_rows > 0) { + rmm::device_uvector n_chunks_per_row( + more_rows + 1, raft::resource::get_cuda_stream(this->config.handle)); + rmm::device_uvector chunk_indices( + 0, raft::resource::get_cuda_stream(this->config.handle)); + chunked_mask_row_it::init(this->config.b_indptr, + mask_indptr.data() + less_rows, + more_rows, + capacity_threshold * map_size, + n_chunks_per_row, + chunk_indices, + raft::resource::get_cuda_stream(this->config.handle)); + + chunked_mask_row_it more(this->config.b_indptr, + more_rows, + mask_indptr.data() + less_rows, + capacity_threshold * map_size, + n_chunks_per_row.data(), + chunk_indices.data(), + raft::resource::get_cuda_stream(this->config.handle)); + + auto n_more_blocks = more.total_row_blocks * n_blocks_per_row; + this->_dispatch_base_rev(*this, + map_size, + more, + out_dists, + coo_rows_a, + product_func, + accum_func, + write_func, + chunk_size, + n_more_blocks, + n_blocks_per_row); + } + } + + __device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size) + { + return insert_type::make_from_uninitialized_slots(cooperative_groups::this_thread_block(), + cache, + cache_size, + cuco::empty_key{value_idx{-1}}, + cuco::empty_value{value_t{0}}); + } + + __device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value) + { + auto success = cache.insert(cuco::pair(key, value)); + } + + __device__ inline find_type init_find(smem_type cache, const value_idx& cache_size) + { + return find_type( + cache, cache_size, cuco::empty_key{value_idx{-1}}, cuco::empty_value{value_t{0}}); + } + + __device__ inline value_t find(find_type cache, const value_idx& key) + { + auto a_pair = cache.find(key); + + value_t a_col = 0.0; + if (a_pair != cache.end()) { a_col = a_pair->second; } + return a_col; + } + + struct fits_in_hash_table { + public: + fits_in_hash_table(const value_idx* indptr_, value_idx degree_l_, value_idx degree_r_) + : indptr(indptr_), degree_l(degree_l_), degree_r(degree_r_) + { + } + + __host__ __device__ bool operator()(const value_idx& i) + { + auto degree = indptr[i + 1] - indptr[i]; + + return degree >= degree_l && degree < degree_r; + } + + private: + const value_idx* indptr; + const value_idx degree_l, degree_r; + }; + + inline static int get_map_size() + { + return (raft::getSharedMemPerBlock() - ((tpb / raft::warp_size()) * sizeof(value_t))) / + sizeof(typename insert_type::slot_type); + } + + private: + float capacity_threshold; + int map_size; +}; + +} // namespace sparse +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/sparse/ip_distance.cuh b/cpp/src/distance/detail/sparse/ip_distance.cuh new file mode 100644 index 000000000..3a11d4e99 --- /dev/null +++ b/cpp/src/distance/detail/sparse/ip_distance.cuh @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "common.hpp" +#include "coo_spmv.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +template +class ip_distances_t : public distances_t { + public: + /** + * Computes simple sparse inner product distances as sum(x_y * y_k) + * @param[in] config specifies inputs, outputs, and sizes + */ + ip_distances_t(const distances_config_t& config) + : config_(&config), coo_rows_b(config.b_nnz, raft::resource::get_cuda_stream(config.handle)) + { + raft::sparse::convert::csr_to_coo(config_->b_indptr, + config_->b_nrows, + coo_rows_b.data(), + config_->b_nnz, + raft::resource::get_cuda_stream(config_->handle)); + } + + /** + * Performs pairwise distance computation and computes output distances + * @param out_distances dense output matrix (size a_nrows * b_nrows) + */ + void compute(value_t* out_distances) + { + /** + * Compute pairwise distances and return dense matrix in row-major format + */ + balanced_coo_pairwise_generalized_spmv(out_distances, + *config_, + coo_rows_b.data(), + raft::mul_op(), + raft::add_op(), + raft::atomic_add_op()); + } + + value_idx* b_rows_coo() { return coo_rows_b.data(); } + + value_t* b_data_coo() { return config_->b_data; } + + private: + const distances_config_t* config_; + rmm::device_uvector coo_rows_b; +}; + +} // END namespace sparse +} // END namespace detail +} // END namespace distance +} // END namespace cuvs diff --git a/cpp/src/distance/detail/sparse/l2_distance.cuh b/cpp/src/distance/detail/sparse/l2_distance.cuh new file mode 100644 index 000000000..40e7070fc --- /dev/null +++ b/cpp/src/distance/detail/sparse/l2_distance.cuh @@ -0,0 +1,502 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "common.hpp" +#include "ip_distance.cuh" +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +// @TODO: Move this into sparse prims (coo_norm) +template +RAFT_KERNEL compute_row_norm_kernel(value_t* out, + const value_idx* __restrict__ coo_rows, + const value_t* __restrict__ data, + value_idx nnz) +{ + value_idx i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < nnz) { atomicAdd(&out[coo_rows[i]], data[i] * data[i]); } +} + +template +RAFT_KERNEL compute_row_sum_kernel(value_t* out, + const value_idx* __restrict__ coo_rows, + const value_t* __restrict__ data, + value_idx nnz) +{ + value_idx i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < nnz) { atomicAdd(&out[coo_rows[i]], data[i]); } +} + +template +RAFT_KERNEL compute_euclidean_warp_kernel(value_t* __restrict__ C, + const value_t* __restrict__ Q_sq_norms, + const value_t* __restrict__ R_sq_norms, + value_idx n_rows, + value_idx n_cols, + expansion_f expansion_func) +{ + std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; + value_idx i = tid / n_cols; + value_idx j = tid % n_cols; + + if (i >= n_rows || j >= n_cols) return; + + value_t dot = C[(size_t)i * n_cols + j]; + + // e.g. Euclidean expansion func = -2.0 * dot + q_norm + r_norm + value_t val = expansion_func(dot, Q_sq_norms[i], R_sq_norms[j]); + + // correct for small instabilities + C[(size_t)i * n_cols + j] = val * (fabs(val) >= 0.0001); +} + +template +RAFT_KERNEL compute_correlation_warp_kernel(value_t* __restrict__ C, + const value_t* __restrict__ Q_sq_norms, + const value_t* __restrict__ R_sq_norms, + const value_t* __restrict__ Q_norms, + const value_t* __restrict__ R_norms, + value_idx n_rows, + value_idx n_cols, + value_idx n) +{ + std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; + value_idx i = tid / n_cols; + value_idx j = tid % n_cols; + + if (i >= n_rows || j >= n_cols) return; + + value_t dot = C[(size_t)i * n_cols + j]; + value_t Q_l1 = Q_norms[i]; + value_t R_l1 = R_norms[j]; + + value_t Q_l2 = Q_sq_norms[i]; + value_t R_l2 = R_sq_norms[j]; + + value_t numer = n * dot - (Q_l1 * R_l1); + value_t Q_denom = n * Q_l2 - (Q_l1 * Q_l1); + value_t R_denom = n * R_l2 - (R_l1 * R_l1); + + value_t val = 1 - (numer / raft::sqrt(Q_denom * R_denom)); + + // correct for small instabilities + C[(size_t)i * n_cols + j] = val * (fabs(val) >= 0.0001); +} + +template +void compute_euclidean(value_t* C, + const value_t* Q_sq_norms, + const value_t* R_sq_norms, + value_idx n_rows, + value_idx n_cols, + cudaStream_t stream, + expansion_f expansion_func) +{ + int blocks = raft::ceildiv((size_t)n_rows * n_cols, tpb); + compute_euclidean_warp_kernel<<>>( + C, Q_sq_norms, R_sq_norms, n_rows, n_cols, expansion_func); +} + +template +void compute_l2(value_t* out, + const value_idx* Q_coo_rows, + const value_t* Q_data, + value_idx Q_nnz, + const value_idx* R_coo_rows, + const value_t* R_data, + value_idx R_nnz, + value_idx m, + value_idx n, + cudaStream_t stream, + expansion_f expansion_func) +{ + rmm::device_uvector Q_sq_norms(m, stream); + rmm::device_uvector R_sq_norms(n, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(Q_sq_norms.data(), 0, Q_sq_norms.size() * sizeof(value_t))); + RAFT_CUDA_TRY(cudaMemsetAsync(R_sq_norms.data(), 0, R_sq_norms.size() * sizeof(value_t))); + + compute_row_norm_kernel<<>>( + Q_sq_norms.data(), Q_coo_rows, Q_data, Q_nnz); + compute_row_norm_kernel<<>>( + R_sq_norms.data(), R_coo_rows, R_data, R_nnz); + + compute_euclidean(out, Q_sq_norms.data(), R_sq_norms.data(), m, n, stream, expansion_func); +} + +template +void compute_correlation(value_t* C, + const value_t* Q_sq_norms, + const value_t* R_sq_norms, + const value_t* Q_norms, + const value_t* R_norms, + value_idx n_rows, + value_idx n_cols, + value_idx n, + cudaStream_t stream) +{ + int blocks = raft::ceildiv((size_t)n_rows * n_cols, tpb); + compute_correlation_warp_kernel<<>>( + C, Q_sq_norms, R_sq_norms, Q_norms, R_norms, n_rows, n_cols, n); +} + +template +void compute_corr(value_t* out, + const value_idx* Q_coo_rows, + const value_t* Q_data, + value_idx Q_nnz, + const value_idx* R_coo_rows, + const value_t* R_data, + value_idx R_nnz, + value_idx m, + value_idx n, + value_idx n_cols, + cudaStream_t stream) +{ + // sum_sq for std dev + rmm::device_uvector Q_sq_norms(m, stream); + rmm::device_uvector R_sq_norms(n, stream); + + // sum for mean + rmm::device_uvector Q_norms(m, stream); + rmm::device_uvector R_norms(n, stream); + + RAFT_CUDA_TRY(cudaMemsetAsync(Q_sq_norms.data(), 0, Q_sq_norms.size() * sizeof(value_t))); + RAFT_CUDA_TRY(cudaMemsetAsync(R_sq_norms.data(), 0, R_sq_norms.size() * sizeof(value_t))); + + RAFT_CUDA_TRY(cudaMemsetAsync(Q_norms.data(), 0, Q_norms.size() * sizeof(value_t))); + RAFT_CUDA_TRY(cudaMemsetAsync(R_norms.data(), 0, R_norms.size() * sizeof(value_t))); + + compute_row_norm_kernel<<>>( + Q_sq_norms.data(), Q_coo_rows, Q_data, Q_nnz); + compute_row_norm_kernel<<>>( + R_sq_norms.data(), R_coo_rows, R_data, R_nnz); + + compute_row_sum_kernel<<>>( + Q_norms.data(), Q_coo_rows, Q_data, Q_nnz); + compute_row_sum_kernel<<>>( + R_norms.data(), R_coo_rows, R_data, R_nnz); + + compute_correlation(out, + Q_sq_norms.data(), + R_sq_norms.data(), + Q_norms.data(), + R_norms.data(), + m, + n, + n_cols, + stream); +} + +/** + * L2 distance using the expanded form: sum(x_k)^2 + sum(y_k)^2 - 2 * sum(x_k * y_k) + * The expanded form is more efficient for sparse data. + */ +template +class l2_expanded_distances_t : public distances_t { + public: + explicit l2_expanded_distances_t(const distances_config_t& config) + : config_(&config), ip_dists(config) + { + } + + void compute(value_t* out_dists) + { + ip_dists.compute(out_dists); + + value_idx* b_indices = ip_dists.b_rows_coo(); + value_t* b_data = ip_dists.b_data_coo(); + + rmm::device_uvector search_coo_rows( + config_->a_nnz, raft::resource::get_cuda_stream(config_->handle)); + raft::sparse::convert::csr_to_coo(config_->a_indptr, + config_->a_nrows, + search_coo_rows.data(), + config_->a_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + compute_l2(out_dists, + search_coo_rows.data(), + config_->a_data, + config_->a_nnz, + b_indices, + b_data, + config_->b_nnz, + config_->a_nrows, + config_->b_nrows, + raft::resource::get_cuda_stream(config_->handle), + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { + return -2 * dot + q_norm + r_norm; + }); + } + + ~l2_expanded_distances_t() = default; + + protected: + const distances_config_t* config_; + ip_distances_t ip_dists; +}; + +/** + * L2 sqrt distance performing the sqrt operation after the distance computation + * The expanded form is more efficient for sparse data. + */ +template +class l2_sqrt_expanded_distances_t : public l2_expanded_distances_t { + public: + explicit l2_sqrt_expanded_distances_t(const distances_config_t& config) + : l2_expanded_distances_t(config) + { + } + + void compute(value_t* out_dists) override + { + l2_expanded_distances_t::compute(out_dists); + // Sqrt Post-processing + raft::linalg::unaryOp( + out_dists, + out_dists, + this->config_->a_nrows * this->config_->b_nrows, + [] __device__(value_t input) { + int neg = input < 0 ? -1 : 1; + return raft::sqrt(abs(input) * neg); + }, + raft::resource::get_cuda_stream(this->config_->handle)); + } + + ~l2_sqrt_expanded_distances_t() = default; +}; + +template +class correlation_expanded_distances_t : public distances_t { + public: + explicit correlation_expanded_distances_t(const distances_config_t& config) + : config_(&config), ip_dists(config) + { + } + + void compute(value_t* out_dists) + { + ip_dists.compute(out_dists); + + value_idx* b_indices = ip_dists.b_rows_coo(); + value_t* b_data = ip_dists.b_data_coo(); + + rmm::device_uvector search_coo_rows( + config_->a_nnz, raft::resource::get_cuda_stream(config_->handle)); + raft::sparse::convert::csr_to_coo(config_->a_indptr, + config_->a_nrows, + search_coo_rows.data(), + config_->a_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + compute_corr(out_dists, + search_coo_rows.data(), + config_->a_data, + config_->a_nnz, + b_indices, + b_data, + config_->b_nnz, + config_->a_nrows, + config_->b_nrows, + config_->b_ncols, + raft::resource::get_cuda_stream(config_->handle)); + } + + ~correlation_expanded_distances_t() = default; + + protected: + const distances_config_t* config_; + ip_distances_t ip_dists; +}; + +/** + * Cosine distance using the expanded form: 1 - ( sum(x_k * y_k) / (sqrt(sum(x_k)^2) * + * sqrt(sum(y_k)^2))) The expanded form is more efficient for sparse data. + */ +template +class cosine_expanded_distances_t : public distances_t { + public: + explicit cosine_expanded_distances_t(const distances_config_t& config) + : config_(&config), + workspace(0, raft::resource::get_cuda_stream(config.handle)), + ip_dists(config) + { + } + + void compute(value_t* out_dists) + { + ip_dists.compute(out_dists); + + value_idx* b_indices = ip_dists.b_rows_coo(); + value_t* b_data = ip_dists.b_data_coo(); + + rmm::device_uvector search_coo_rows( + config_->a_nnz, raft::resource::get_cuda_stream(config_->handle)); + raft::sparse::convert::csr_to_coo(config_->a_indptr, + config_->a_nrows, + search_coo_rows.data(), + config_->a_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + compute_l2(out_dists, + search_coo_rows.data(), + config_->a_data, + config_->a_nnz, + b_indices, + b_data, + config_->b_nnz, + config_->a_nrows, + config_->b_nrows, + raft::resource::get_cuda_stream(config_->handle), + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { + value_t norms = raft::sqrt(q_norm) * raft::sqrt(r_norm); + // deal with potential for 0 in denominator by forcing 0/1 instead + value_t cos = ((norms != 0) * dot) / ((norms == 0) + norms); + + // flip the similarity when both rows are 0 + bool both_empty = (q_norm == 0) && (r_norm == 0); + return 1 - ((!both_empty * cos) + both_empty); + }); + } + + ~cosine_expanded_distances_t() = default; + + private: + const distances_config_t* config_; + rmm::device_uvector workspace; + ip_distances_t ip_dists; +}; + +/** + * Hellinger distance using the expanded form: sqrt(1 - sum(sqrt(x_k) * sqrt(y_k))) + * The expanded form is more efficient for sparse data. + * + * This distance computation modifies A and B by computing a sqrt + * and then performing a `pow(x, 2)` to convert it back. Because of this, + * it is possible that the values in A and B might differ slightly + * after this is invoked. + */ +template +class hellinger_expanded_distances_t : public distances_t { + public: + explicit hellinger_expanded_distances_t(const distances_config_t& config) + : config_(&config), workspace(0, raft::resource::get_cuda_stream(config.handle)) + { + } + + void compute(value_t* out_dists) + { + rmm::device_uvector coo_rows(std::max(config_->b_nnz, config_->a_nnz), + raft::resource::get_cuda_stream(config_->handle)); + + raft::sparse::convert::csr_to_coo(config_->b_indptr, + config_->b_nrows, + coo_rows.data(), + config_->b_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + balanced_coo_pairwise_generalized_spmv( + out_dists, + *config_, + coo_rows.data(), + [] __device__(value_t a, value_t b) { return raft::sqrt(a) * raft::sqrt(b); }, + raft::add_op(), + raft::atomic_add_op()); + + raft::linalg::unaryOp( + out_dists, + out_dists, + config_->a_nrows * config_->b_nrows, + [=] __device__(value_t input) { + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + bool rectifier = (1 - input) > 0; + return raft::sqrt(rectifier * (1 - input)); + }, + raft::resource::get_cuda_stream(config_->handle)); + } + + ~hellinger_expanded_distances_t() = default; + + private: + const distances_config_t* config_; + rmm::device_uvector workspace; +}; + +template +class russelrao_expanded_distances_t : public distances_t { + public: + explicit russelrao_expanded_distances_t(const distances_config_t& config) + : config_(&config), + workspace(0, raft::resource::get_cuda_stream(config.handle)), + ip_dists(config) + { + } + + void compute(value_t* out_dists) + { + ip_dists.compute(out_dists); + + value_t n_cols = config_->a_ncols; + value_t n_cols_inv = 1.0 / n_cols; + raft::linalg::unaryOp( + out_dists, + out_dists, + config_->a_nrows * config_->b_nrows, + [=] __device__(value_t input) { return (n_cols - input) * n_cols_inv; }, + raft::resource::get_cuda_stream(config_->handle)); + + auto exec_policy = rmm::exec_policy(raft::resource::get_cuda_stream(config_->handle)); + auto diags = thrust::counting_iterator(0); + value_idx b_nrows = config_->b_nrows; + thrust::for_each(exec_policy, diags, diags + config_->a_nrows, [=] __device__(value_idx input) { + out_dists[input * b_nrows + input] = 0.0; + }); + } + + ~russelrao_expanded_distances_t() = default; + + private: + const distances_config_t* config_; + rmm::device_uvector workspace; + ip_distances_t ip_dists; +}; + +} // END namespace sparse +} // END namespace detail +} // END namespace distance +} // END namespace cuvs diff --git a/cpp/src/distance/detail/sparse/lp_distance.cuh b/cpp/src/distance/detail/sparse/lp_distance.cuh new file mode 100644 index 000000000..18e7b04e4 --- /dev/null +++ b/cpp/src/distance/detail/sparse/lp_distance.cuh @@ -0,0 +1,333 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "common.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +template +void unexpanded_lp_distances(value_t* out_dists, + const distances_config_t* config_, + product_f product_func, + accum_f accum_func, + write_f write_func) +{ + rmm::device_uvector coo_rows(std::max(config_->b_nnz, config_->a_nnz), + raft::resource::get_cuda_stream(config_->handle)); + + raft::sparse::convert::csr_to_coo(config_->b_indptr, + config_->b_nrows, + coo_rows.data(), + config_->b_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + balanced_coo_pairwise_generalized_spmv( + out_dists, *config_, coo_rows.data(), product_func, accum_func, write_func); + + raft::sparse::convert::csr_to_coo(config_->a_indptr, + config_->a_nrows, + coo_rows.data(), + config_->a_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + balanced_coo_pairwise_generalized_spmv_rev( + out_dists, *config_, coo_rows.data(), product_func, accum_func, write_func); +} + +/** + * Computes L1 distances for sparse input. This does not have + * an equivalent expanded form, so it is only executed in + * an unexpanded form. + * @tparam value_idx + * @tparam value_t + */ +template +class l1_unexpanded_distances_t : public distances_t { + public: + l1_unexpanded_distances_t(const distances_config_t& config) : config_(&config) + { + } + + void compute(value_t* out_dists) + { + unexpanded_lp_distances( + out_dists, config_, raft::absdiff_op(), raft::add_op(), raft::atomic_add_op()); + } + + private: + const distances_config_t* config_; +}; + +template +class l2_unexpanded_distances_t : public distances_t { + public: + l2_unexpanded_distances_t(const distances_config_t& config) : config_(&config) + { + } + + void compute(value_t* out_dists) + { + unexpanded_lp_distances( + out_dists, config_, raft::sqdiff_op(), raft::add_op(), raft::atomic_add_op()); + } + + protected: + const distances_config_t* config_; +}; + +template +class l2_sqrt_unexpanded_distances_t : public l2_unexpanded_distances_t { + public: + l2_sqrt_unexpanded_distances_t(const distances_config_t& config) + : l2_unexpanded_distances_t(config) + { + } + + void compute(value_t* out_dists) + { + l2_unexpanded_distances_t::compute(out_dists); + + uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; + // Sqrt Post-processing + raft::linalg::unaryOp( + out_dists, + out_dists, + n, + [] __device__(value_t input) { + int neg = input < 0 ? -1 : 1; + return raft::sqrt(abs(input) * neg); + }, + raft::resource::get_cuda_stream(this->config_->handle)); + } +}; + +template +class linf_unexpanded_distances_t : public distances_t { + public: + explicit linf_unexpanded_distances_t(const distances_config_t& config) + : config_(&config) + { + } + + void compute(value_t* out_dists) + { + unexpanded_lp_distances( + out_dists, config_, raft::absdiff_op(), raft::max_op(), raft::atomic_max_op()); + } + + private: + const distances_config_t* config_; +}; + +template +class canberra_unexpanded_distances_t : public distances_t { + public: + explicit canberra_unexpanded_distances_t(const distances_config_t& config) + : config_(&config) + { + } + + void compute(value_t* out_dists) + { + unexpanded_lp_distances( + out_dists, + config_, + [] __device__(value_t a, value_t b) { + value_t d = fabs(a) + fabs(b); + + // deal with potential for 0 in denominator by + // forcing 1/0 instead + return ((d != 0) * fabs(a - b)) / (d + (d == 0)); + }, + raft::add_op(), + raft::atomic_add_op()); + } + + private: + const distances_config_t* config_; +}; + +template +class lp_unexpanded_distances_t : public distances_t { + public: + explicit lp_unexpanded_distances_t(const distances_config_t& config, + value_t p_) + : config_(&config), p(p_) + { + } + + void compute(value_t* out_dists) + { + unexpanded_lp_distances( + out_dists, + config_, + raft::compose_op(raft::pow_const_op(p), raft::sub_op()), + raft::add_op(), + raft::atomic_add_op()); + + uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; + value_t one_over_p = value_t{1} / p; + raft::linalg::unaryOp(out_dists, + out_dists, + n, + raft::pow_const_op(one_over_p), + raft::resource::get_cuda_stream(config_->handle)); + } + + private: + const distances_config_t* config_; + value_t p; +}; + +template +class hamming_unexpanded_distances_t : public distances_t { + public: + explicit hamming_unexpanded_distances_t(const distances_config_t& config) + : config_(&config) + { + } + + void compute(value_t* out_dists) + { + unexpanded_lp_distances( + out_dists, config_, raft::notequal_op(), raft::add_op(), raft::atomic_add_op()); + + uint64_t n = (uint64_t)config_->a_nrows * (uint64_t)config_->b_nrows; + value_t n_cols = 1.0 / config_->a_ncols; + raft::linalg::unaryOp(out_dists, + out_dists, + n, + raft::mul_const_op(n_cols), + raft::resource::get_cuda_stream(config_->handle)); + } + + private: + const distances_config_t* config_; +}; + +template +class jensen_shannon_unexpanded_distances_t : public distances_t { + public: + explicit jensen_shannon_unexpanded_distances_t( + const distances_config_t& config) + : config_(&config) + { + } + + void compute(value_t* out_dists) + { + unexpanded_lp_distances( + out_dists, + config_, + [] __device__(value_t a, value_t b) { + value_t m = 0.5f * (a + b); + bool a_zero = a == 0; + bool b_zero = b == 0; + + value_t x = (!a_zero * m) / (a_zero + a); + value_t y = (!b_zero * m) / (b_zero + b); + + bool x_zero = x == 0; + bool y_zero = y == 0; + + return (-a * (!x_zero * log(x + x_zero))) + (-b * (!y_zero * log(y + y_zero))); + }, + raft::add_op(), + raft::atomic_add_op()); + + uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; + raft::linalg::unaryOp( + out_dists, + out_dists, + n, + [=] __device__(value_t input) { return raft::sqrt(0.5 * input); }, + raft::resource::get_cuda_stream(config_->handle)); + } + + private: + const distances_config_t* config_; +}; + +template +class kl_divergence_unexpanded_distances_t : public distances_t { + public: + explicit kl_divergence_unexpanded_distances_t( + const distances_config_t& config) + : config_(&config) + { + } + + void compute(value_t* out_dists) + { + rmm::device_uvector coo_rows(std::max(config_->b_nnz, config_->a_nnz), + raft::resource::get_cuda_stream(config_->handle)); + + raft::sparse::convert::csr_to_coo(config_->b_indptr, + config_->b_nrows, + coo_rows.data(), + config_->b_nnz, + raft::resource::get_cuda_stream(config_->handle)); + + balanced_coo_pairwise_generalized_spmv( + out_dists, + *config_, + coo_rows.data(), + [] __device__(value_t a, value_t b) { return a * log(a / b); }, + raft::add_op(), + raft::atomic_add_op()); + + uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; + raft::linalg::unaryOp(out_dists, + out_dists, + n, + raft::mul_const_op(0.5), + raft::resource::get_cuda_stream(config_->handle)); + } + + private: + const distances_config_t* config_; +}; + +} // END namespace sparse +} // END namespace detail +} // END namespace distance +} // END namespace cuvs diff --git a/cpp/src/distance/detail/sparse/utils.cuh b/cpp/src/distance/detail/sparse/utils.cuh new file mode 100644 index 000000000..dc7ae6df6 --- /dev/null +++ b/cpp/src/distance/detail/sparse/utils.cuh @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace cuvs { +namespace distance { +namespace detail { +namespace sparse { + +/** + * Computes the maximum number of columns that can be stored + * in shared memory in dense form with the given block size + * and precision. + * @return the maximum number of columns that can be stored in smem + */ +template +inline int max_cols_per_block() +{ + // max cols = (total smem available - cub reduction smem) + return (raft::getSharedMemPerBlock() - ((tpb / raft::warp_size()) * sizeof(value_t))) / + sizeof(value_t); +} + +template +RAFT_KERNEL faster_dot_on_csr_kernel(dot_t* __restrict__ dot, + const value_idx* __restrict__ indptr, + const value_idx* __restrict__ cols, + const value_t* __restrict__ A, + const value_t* __restrict__ B, + const value_idx nnz, + const value_idx n_rows, + const value_idx dim) +{ + auto vec_id = threadIdx.x; + auto lane_id = threadIdx.x & 0x1f; + + extern __shared__ char smem[]; + value_t* s_A = (value_t*)smem; + value_idx cur_row = -1; + + for (int row = blockIdx.x; row < n_rows; row += gridDim.x) { + for (int dot_id = blockIdx.y + indptr[row]; dot_id < indptr[row + 1]; dot_id += gridDim.y) { + if (dot_id >= nnz) { return; } + const value_idx col = cols[dot_id] * dim; + const value_t* __restrict__ B_col = B + col; + + if (threadIdx.x == 0) { dot[dot_id] = 0.0; } + __syncthreads(); + + if (cur_row != row) { + for (value_idx k = vec_id; k < dim; k += blockDim.x) { + s_A[k] = A[row * dim + k]; + } + cur_row = row; + } + + dot_t l_dot_ = 0.0; + for (value_idx k = vec_id; k < dim; k += blockDim.x) { + asm("prefetch.global.L2 [%0];" ::"l"(B_col + k + blockDim.x)); + if constexpr ((std::is_same_v && std::is_same_v)) { + l_dot_ += __half2float(s_A[k]) * __half2float(__ldcg(B_col + k)); + } else { + l_dot_ += s_A[k] * __ldcg(B_col + k); + } + } + + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage; + dot_t warp_sum = WarpReduce(temp_storage).Sum(l_dot_); + + if (lane_id == 0) { atomicAdd_block(dot + dot_id, warp_sum); } + } + } +} + +template +void faster_dot_on_csr(raft::resources const& handle, + dot_t* dot, + const value_idx nnz, + const value_idx* indptr, + const value_idx* cols, + const value_t* A, + const value_t* B, + const value_idx n_rows, + const value_idx dim) +{ + if (nnz == 0 || n_rows == 0) return; + + auto stream = raft::resource::get_cuda_stream(handle); + + constexpr value_idx MAX_ROW_PER_ITER = 500; + int dev_id, sm_count, blocks_per_sm; + + const int smem_size = dim * sizeof(value_t); + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + if (dim < 128) { + constexpr int tpb = 64; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + + } else if (dim < 256) { + constexpr int tpb = 128; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + } else if (dim < 512) { + constexpr int tpb = 256; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + } else { + constexpr int tpb = 512; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + } + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // namespace sparse +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/sparse_distance.cu b/cpp/src/distance/sparse_distance.cu new file mode 100644 index 000000000..338c4e908 --- /dev/null +++ b/cpp/src/distance/sparse_distance.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "sparse_distance.cuh" + +namespace cuvs { +namespace distance { + +template +void pairwise_distance( + raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f) +{ + auto x_structure = x.structure_view(); + auto y_structure = y.structure_view(); + + RAFT_EXPECTS(x_structure.get_n_cols() == y_structure.get_n_cols(), + "Number of columns must be equal"); + + RAFT_EXPECTS(dist.extent(0) == x_structure.get_n_rows(), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y_structure.get_n_rows(), + "Number of columns in output must be equal to " + "number of rows in Y"); + + detail::sparse::distances_config_t input_config(handle); + input_config.a_nrows = x_structure.get_n_rows(); + input_config.a_ncols = x_structure.get_n_cols(); + input_config.a_nnz = x_structure.get_nnz(); + input_config.a_indptr = const_cast(x_structure.get_indptr().data()); + input_config.a_indices = const_cast(x_structure.get_indices().data()); + input_config.a_data = const_cast(x.get_elements().data()); + + input_config.b_nrows = y_structure.get_n_rows(); + input_config.b_ncols = y_structure.get_n_cols(); + input_config.b_nnz = y_structure.get_nnz(); + input_config.b_indptr = const_cast(y_structure.get_indptr().data()); + input_config.b_indices = const_cast(y_structure.get_indices().data()); + input_config.b_data = const_cast(y.get_elements().data()); + + pairwiseDistance(dist.data_handle(), input_config, metric, metric_arg); +} + +void pairwise_distance(raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + pairwise_distance(handle, x, y, dist, metric, metric_arg); +} + +void pairwise_distance(raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + pairwise_distance(handle, x, y, dist, metric, metric_arg); +} +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/sparse_distance.cuh b/cpp/src/distance/sparse_distance.cuh new file mode 100644 index 000000000..0d6dc0e6f --- /dev/null +++ b/cpp/src/distance/sparse_distance.cuh @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/sparse/bin_distance.cuh" +#include "detail/sparse/common.hpp" +#include "detail/sparse/ip_distance.cuh" +#include "detail/sparse/l2_distance.cuh" +#include "detail/sparse/lp_distance.cuh" + +#include + +#include + +#include + +namespace cuvs { +namespace distance { +/** + * Compute pairwise distances between A and B, using the provided + * input configuration and distance function. + * + * @tparam value_idx index type + * @tparam value_t value type + * @param[out] out dense output array (size A.nrows * B.nrows) + * @param[in] input_config input argument configuration + * @param[in] metric distance metric to use + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwiseDistance(value_t* out, + detail::sparse::distances_config_t input_config, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + switch (metric) { + case cuvs::distance::DistanceType::L2Expanded: + detail::sparse::l2_expanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::L2SqrtExpanded: + detail::sparse::l2_sqrt_expanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::InnerProduct: + detail::sparse::ip_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::L2Unexpanded: + detail::sparse::l2_unexpanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::L2SqrtUnexpanded: + detail::sparse::l2_sqrt_unexpanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::L1: + detail::sparse::l1_unexpanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::LpUnexpanded: + detail::sparse::lp_unexpanded_distances_t(input_config, metric_arg) + .compute(out); + break; + case cuvs::distance::DistanceType::Linf: + detail::sparse::linf_unexpanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::Canberra: + detail::sparse::canberra_unexpanded_distances_t(input_config) + .compute(out); + break; + case cuvs::distance::DistanceType::JaccardExpanded: + detail::sparse::jaccard_expanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::CosineExpanded: + detail::sparse::cosine_expanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::HellingerExpanded: + detail::sparse::hellinger_expanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::DiceExpanded: + detail::sparse::dice_expanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::CorrelationExpanded: + detail::sparse::correlation_expanded_distances_t(input_config) + .compute(out); + break; + case cuvs::distance::DistanceType::RusselRaoExpanded: + detail::sparse::russelrao_expanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::HammingUnexpanded: + detail::sparse::hamming_unexpanded_distances_t(input_config).compute(out); + break; + case cuvs::distance::DistanceType::JensenShannon: + detail::sparse::jensen_shannon_unexpanded_distances_t(input_config) + .compute(out); + break; + case cuvs::distance::DistanceType::KLDivergence: + detail::sparse::kl_divergence_unexpanded_distances_t(input_config) + .compute(out); + break; + + default: THROW("Unsupported distance: %d", metric); + } +} +}; // namespace distance +}; // namespace cuvs diff --git a/cpp/src/neighbors/detail/sparse_knn.cuh b/cpp/src/neighbors/detail/sparse_knn.cuh new file mode 100644 index 000000000..9c8e971b9 --- /dev/null +++ b/cpp/src/neighbors/detail/sparse_knn.cuh @@ -0,0 +1,437 @@ +/* + * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "../../distance/sparse_distance.cuh" +#include "knn_merge_parts.cuh" +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace cuvs::neighbors::detail { + +template +struct csr_batcher_t { + csr_batcher_t(value_idx batch_size, + value_idx n_rows, + const value_idx* csr_indptr, + const value_idx* csr_indices, + const value_t* csr_data) + : batch_start_(0), + batch_stop_(0), + batch_rows_(0), + total_rows_(n_rows), + batch_size_(batch_size), + csr_indptr_(csr_indptr), + csr_indices_(csr_indices), + csr_data_(csr_data), + batch_csr_start_offset_(0), + batch_csr_stop_offset_(0) + { + } + + void set_batch(int batch_num) + { + batch_start_ = batch_num * batch_size_; + batch_stop_ = batch_start_ + batch_size_ - 1; // zero-based indexing + + if (batch_stop_ >= total_rows_) batch_stop_ = total_rows_ - 1; // zero-based indexing + + batch_rows_ = (batch_stop_ - batch_start_) + 1; + } + + value_idx get_batch_csr_indptr_nnz(value_idx* batch_indptr, cudaStream_t stream) + { + raft::sparse::op::csr_row_slice_indptr(batch_start_, + batch_stop_, + csr_indptr_, + batch_indptr, + &batch_csr_start_offset_, + &batch_csr_stop_offset_, + stream); + + return batch_csr_stop_offset_ - batch_csr_start_offset_; + } + + void get_batch_csr_indices_data(value_idx* csr_indices, value_t* csr_data, cudaStream_t stream) + { + raft::sparse::op::csr_row_slice_populate(batch_csr_start_offset_, + batch_csr_stop_offset_, + csr_indices_, + csr_data_, + csr_indices, + csr_data, + stream); + } + + value_idx batch_rows() const { return batch_rows_; } + + value_idx batch_start() const { return batch_start_; } + + value_idx batch_stop() const { return batch_stop_; } + + private: + value_idx batch_size_; + value_idx batch_start_; + value_idx batch_stop_; + value_idx batch_rows_; + + value_idx total_rows_; + + const value_idx* csr_indptr_; + const value_idx* csr_indices_; + const value_t* csr_data_; + + value_idx batch_csr_start_offset_; + value_idx batch_csr_stop_offset_; +}; + +template +class sparse_knn_t { + public: + sparse_knn_t(const value_idx* idxIndptr_, + const value_idx* idxIndices_, + const value_t* idxData_, + size_t idxNNZ_, + int n_idx_rows_, + int n_idx_cols_, + const value_idx* queryIndptr_, + const value_idx* queryIndices_, + const value_t* queryData_, + size_t queryNNZ_, + int n_query_rows_, + int n_query_cols_, + value_idx* output_indices_, + value_t* output_dists_, + int k_, + raft::resources const& handle_, + size_t batch_size_index_ = 2 << 14, // approx 1M + size_t batch_size_query_ = 2 << 14, + cuvs::distance::DistanceType metric_ = cuvs::distance::DistanceType::L2Expanded, + float metricArg_ = 0) + : idxIndptr(idxIndptr_), + idxIndices(idxIndices_), + idxData(idxData_), + idxNNZ(idxNNZ_), + n_idx_rows(n_idx_rows_), + n_idx_cols(n_idx_cols_), + queryIndptr(queryIndptr_), + queryIndices(queryIndices_), + queryData(queryData_), + queryNNZ(queryNNZ_), + n_query_rows(n_query_rows_), + n_query_cols(n_query_cols_), + output_indices(output_indices_), + output_dists(output_dists_), + k(k_), + handle(handle_), + batch_size_index(batch_size_index_), + batch_size_query(batch_size_query_), + metric(metric_), + metricArg(metricArg_) + { + } + + void run() + { + using namespace raft::sparse; + + int n_batches_query = raft::ceildiv((size_t)n_query_rows, batch_size_query); + csr_batcher_t query_batcher( + batch_size_query, n_query_rows, queryIndptr, queryIndices, queryData); + + size_t rows_processed = 0; + + for (int i = 0; i < n_batches_query; i++) { + /** + * Compute index batch info + */ + query_batcher.set_batch(i); + + /** + * Slice CSR to rows in batch + */ + + rmm::device_uvector query_batch_indptr(query_batcher.batch_rows() + 1, + raft::resource::get_cuda_stream(handle)); + + value_idx n_query_batch_nnz = query_batcher.get_batch_csr_indptr_nnz( + query_batch_indptr.data(), raft::resource::get_cuda_stream(handle)); + + rmm::device_uvector query_batch_indices(n_query_batch_nnz, + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector query_batch_data(n_query_batch_nnz, + raft::resource::get_cuda_stream(handle)); + + query_batcher.get_batch_csr_indices_data(query_batch_indices.data(), + query_batch_data.data(), + raft::resource::get_cuda_stream(handle)); + + // A 3-partition temporary merge space to scale the batching. 2 parts for subsequent + // batches and 1 space for the results of the merge, which get copied back to the top + rmm::device_uvector merge_buffer_indices(0, + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector merge_buffer_dists(0, raft::resource::get_cuda_stream(handle)); + + value_t* dists_merge_buffer_ptr; + value_idx* indices_merge_buffer_ptr; + + int n_batches_idx = raft::ceildiv((size_t)n_idx_rows, batch_size_index); + csr_batcher_t idx_batcher( + batch_size_index, n_idx_rows, idxIndptr, idxIndices, idxData); + + for (int j = 0; j < n_batches_idx; j++) { + idx_batcher.set_batch(j); + + merge_buffer_indices.resize(query_batcher.batch_rows() * k * 3, + raft::resource::get_cuda_stream(handle)); + merge_buffer_dists.resize(query_batcher.batch_rows() * k * 3, + raft::resource::get_cuda_stream(handle)); + + /** + * Slice CSR to rows in batch + */ + rmm::device_uvector idx_batch_indptr(idx_batcher.batch_rows() + 1, + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector idx_batch_indices(0, + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector idx_batch_data(0, raft::resource::get_cuda_stream(handle)); + + value_idx idx_batch_nnz = idx_batcher.get_batch_csr_indptr_nnz( + idx_batch_indptr.data(), raft::resource::get_cuda_stream(handle)); + + idx_batch_indices.resize(idx_batch_nnz, raft::resource::get_cuda_stream(handle)); + idx_batch_data.resize(idx_batch_nnz, raft::resource::get_cuda_stream(handle)); + + idx_batcher.get_batch_csr_indices_data( + idx_batch_indices.data(), idx_batch_data.data(), raft::resource::get_cuda_stream(handle)); + + /** + * Compute distances + */ + uint64_t dense_size = + (uint64_t)idx_batcher.batch_rows() * (uint64_t)query_batcher.batch_rows(); + rmm::device_uvector batch_dists(dense_size, + raft::resource::get_cuda_stream(handle)); + + RAFT_CUDA_TRY(cudaMemset(batch_dists.data(), 0, batch_dists.size() * sizeof(value_t))); + + compute_distances(idx_batcher, + query_batcher, + idx_batch_nnz, + n_query_batch_nnz, + idx_batch_indptr.data(), + idx_batch_indices.data(), + idx_batch_data.data(), + query_batch_indptr.data(), + query_batch_indices.data(), + query_batch_data.data(), + batch_dists.data()); + + // Build batch indices array + rmm::device_uvector batch_indices(batch_dists.size(), + raft::resource::get_cuda_stream(handle)); + + // populate batch indices array + value_idx batch_rows = query_batcher.batch_rows(), batch_cols = idx_batcher.batch_rows(); + + iota_fill( + batch_indices.data(), batch_rows, batch_cols, raft::resource::get_cuda_stream(handle)); + + /** + * Perform k-selection on batch & merge with other k-selections + */ + size_t merge_buffer_offset = batch_rows * k; + dists_merge_buffer_ptr = merge_buffer_dists.data() + merge_buffer_offset; + indices_merge_buffer_ptr = merge_buffer_indices.data() + merge_buffer_offset; + + perform_k_selection(idx_batcher, + query_batcher, + batch_dists.data(), + batch_indices.data(), + dists_merge_buffer_ptr, + indices_merge_buffer_ptr); + + value_t* dists_merge_buffer_tmp_ptr = dists_merge_buffer_ptr; + value_idx* indices_merge_buffer_tmp_ptr = indices_merge_buffer_ptr; + + // Merge results of difference batches if necessary + if (idx_batcher.batch_start() > 0) { + size_t merge_buffer_tmp_out = batch_rows * k * 2; + dists_merge_buffer_tmp_ptr = merge_buffer_dists.data() + merge_buffer_tmp_out; + indices_merge_buffer_tmp_ptr = merge_buffer_indices.data() + merge_buffer_tmp_out; + + merge_batches(idx_batcher, + query_batcher, + merge_buffer_dists.data(), + merge_buffer_indices.data(), + dists_merge_buffer_tmp_ptr, + indices_merge_buffer_tmp_ptr); + } + + // copy merged output back into merge buffer partition for next iteration + raft::copy_async(merge_buffer_indices.data(), + indices_merge_buffer_tmp_ptr, + batch_rows * k, + raft::resource::get_cuda_stream(handle)); + raft::copy_async(merge_buffer_dists.data(), + dists_merge_buffer_tmp_ptr, + batch_rows * k, + raft::resource::get_cuda_stream(handle)); + } + + // Copy final merged batch to output array + raft::copy_async(output_indices + (rows_processed * k), + merge_buffer_indices.data(), + query_batcher.batch_rows() * k, + raft::resource::get_cuda_stream(handle)); + raft::copy_async(output_dists + (rows_processed * k), + merge_buffer_dists.data(), + query_batcher.batch_rows() * k, + raft::resource::get_cuda_stream(handle)); + + rows_processed += query_batcher.batch_rows(); + } + } + + private: + void merge_batches(csr_batcher_t& idx_batcher, + csr_batcher_t& query_batcher, + value_t* merge_buffer_dists, + value_idx* merge_buffer_indices, + value_t* out_dists, + value_idx* out_indices) + { + // build translation buffer to shift resulting indices by the batch + std::vector id_ranges; + id_ranges.push_back(0); + id_ranges.push_back(idx_batcher.batch_start()); + + rmm::device_uvector trans(id_ranges.size(), raft::resource::get_cuda_stream(handle)); + raft::update_device( + trans.data(), id_ranges.data(), id_ranges.size(), raft::resource::get_cuda_stream(handle)); + + // combine merge buffers only if there's more than 1 partition to combine + cuvs::neighbors::detail::knn_merge_parts(merge_buffer_dists, + merge_buffer_indices, + out_dists, + out_indices, + query_batcher.batch_rows(), + 2, + k, + raft::resource::get_cuda_stream(handle), + trans.data()); + } + + void perform_k_selection(csr_batcher_t idx_batcher, + csr_batcher_t query_batcher, + value_t* batch_dists, + value_idx* batch_indices, + value_t* out_dists, + value_idx* out_indices) + { + // populate batch indices array + value_idx batch_rows = query_batcher.batch_rows(), batch_cols = idx_batcher.batch_rows(); + + // build translation buffer to shift resulting indices by the batch + std::vector id_ranges; + id_ranges.push_back(0); + id_ranges.push_back(idx_batcher.batch_start()); + + // in the case where the number of idx rows in the batch is < k, we + // want to adjust k. + value_idx n_neighbors = std::min(static_cast(k), batch_cols); + + bool ascending = cuvs::distance::is_min_close(metric); + + // kernel to slice first (min) k cols and copy into batched merge buffer + cuvs::selection::select_k( + handle, + raft::make_device_matrix_view(batch_dists, batch_rows, batch_cols), + raft::make_device_matrix_view( + batch_indices, batch_rows, batch_cols), + raft::make_device_matrix_view(out_dists, batch_rows, n_neighbors), + raft::make_device_matrix_view(out_indices, batch_rows, n_neighbors), + ascending, + true); + } + + void compute_distances(csr_batcher_t& idx_batcher, + csr_batcher_t& query_batcher, + size_t idx_batch_nnz, + size_t query_batch_nnz, + value_idx* idx_batch_indptr, + value_idx* idx_batch_indices, + value_t* idx_batch_data, + value_idx* query_batch_indptr, + value_idx* query_batch_indices, + value_t* query_batch_data, + value_t* batch_dists) + { + /** + * Compute distances + */ + cuvs::distance::detail::sparse::distances_config_t dist_config(handle); + dist_config.b_nrows = idx_batcher.batch_rows(); + dist_config.b_ncols = n_idx_cols; + dist_config.b_nnz = idx_batch_nnz; + + dist_config.b_indptr = idx_batch_indptr; + dist_config.b_indices = idx_batch_indices; + dist_config.b_data = idx_batch_data; + + dist_config.a_nrows = query_batcher.batch_rows(); + dist_config.a_ncols = n_query_cols; + dist_config.a_nnz = query_batch_nnz; + + dist_config.a_indptr = query_batch_indptr; + dist_config.a_indices = query_batch_indices; + dist_config.a_data = query_batch_data; + + cuvs::distance::pairwiseDistance(batch_dists, dist_config, metric, metricArg); + } + + const value_idx *idxIndptr, *idxIndices, *queryIndptr, *queryIndices; + value_idx* output_indices; + const value_t *idxData, *queryData; + value_t* output_dists; + + size_t idxNNZ, queryNNZ, batch_size_index, batch_size_query; + + cuvs::distance::DistanceType metric; + + float metricArg; + + int n_idx_rows, n_idx_cols, n_query_rows, n_query_cols, k; + + raft::resources const& handle; +}; + +}; // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/sparse_brute_force.cu b/cpp/src/neighbors/sparse_brute_force.cu new file mode 100644 index 000000000..e277961ec --- /dev/null +++ b/cpp/src/neighbors/sparse_brute_force.cu @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "detail/sparse_knn.cuh" + +namespace cuvs::neighbors::brute_force { +template +sparse_index::sparse_index(raft::resources const& res, + raft::device_csr_matrix_view dataset, + cuvs::distance::DistanceType metric, + T metric_arg) + : dataset_(dataset), metric_(metric), metric_arg_(metric_arg) +{ +} + +auto build(raft::resources const& handle, + raft::device_csr_matrix_view dataset, + cuvs::distance::DistanceType metric, + float metric_arg) -> cuvs::neighbors::brute_force::sparse_index +{ + return sparse_index(handle, dataset, metric, metric_arg); +} + +void search(raft::resources const& handle, + const sparse_search_params& params, + const sparse_index& index, + raft::device_csr_matrix_view query, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + auto idx_structure = index.dataset().structure_view(); + auto query_structure = query.structure_view(); + int k = neighbors.extent(1); + + detail::sparse_knn_t(idx_structure.get_indptr().data(), + idx_structure.get_indices().data(), + index.dataset().get_elements().data(), + idx_structure.get_nnz(), + idx_structure.get_n_rows(), + idx_structure.get_n_cols(), + query_structure.get_indptr().data(), + query_structure.get_indices().data(), + query.get_elements().data(), + query_structure.get_nnz(), + query_structure.get_n_rows(), + query_structure.get_n_cols(), + neighbors.data_handle(), + distances.data_handle(), + k, + handle, + params.batch_size_index, + params.batch_size_query, + index.metric(), + index.metric_arg()) + .run(); +} +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 7754a5043..286d721d7 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -94,7 +94,7 @@ endfunction() if(BUILD_TESTS) ConfigureTest( NAME NEIGHBORS_TEST PATH neighbors/brute_force.cu neighbors/brute_force_prefiltered.cu - neighbors/refine.cu GPUS 1 PERCENT 100 + neighbors/sparse_brute_force.cu neighbors/refine.cu GPUS 1 PERCENT 100 ) ConfigureTest( @@ -206,6 +206,7 @@ if(BUILD_TESTS) distance/dist_lp_unexp.cu distance/dist_russell_rao.cu distance/masked_nn.cu + distance/sparse_distance.cu sparse/neighbors/cross_component_nn.cu GPUS 1 diff --git a/cpp/test/distance/sparse_distance.cu b/cpp/test/distance/sparse_distance.cu new file mode 100644 index 000000000..f95487414 --- /dev/null +++ b/cpp/test/distance/sparse_distance.cu @@ -0,0 +1,850 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include + +#include + +#include +#include + +namespace cuvs { +namespace distance { + +using namespace raft; +using namespace raft::sparse; + +template +struct SparseDistanceInputs { + value_idx n_cols; + + std::vector indptr_h; + std::vector indices_h; + std::vector data_h; + + std::vector out_dists_ref_h; + + cuvs::distance::DistanceType metric; + + float metric_arg = 0.0; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const SparseDistanceInputs& dims) +{ + return os; +} + +template +class SparseDistanceTest + : public ::testing::TestWithParam> { + public: + SparseDistanceTest() + : params(::testing::TestWithParam>::GetParam()), + indptr(0, resource::get_cuda_stream(handle)), + indices(0, resource::get_cuda_stream(handle)), + data(0, resource::get_cuda_stream(handle)), + out_dists(0, resource::get_cuda_stream(handle)), + out_dists_ref(0, resource::get_cuda_stream(handle)) + { + } + + void SetUp() override + { + make_data(); + + int out_size = static_cast(params.indptr_h.size() - 1) * + static_cast(params.indptr_h.size() - 1); + + out_dists.resize(out_size, resource::get_cuda_stream(handle)); + + auto out = raft::make_device_matrix_view( + out_dists.data(), + static_cast(params.indptr_h.size() - 1), + static_cast(params.indptr_h.size() - 1)); + + auto x_structure = raft::make_device_compressed_structure_view( + indptr.data(), + indices.data(), + static_cast(params.indptr_h.size() - 1), + params.n_cols, + static_cast(params.indices_h.size())); + auto x = raft::make_device_csr_matrix_view(data.data(), x_structure); + + cuvs::distance::pairwise_distance(handle, x, x, out, params.metric, params.metric_arg); + + RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle))); + } + + void compare() + { + ASSERT_TRUE(devArrMatch(out_dists_ref.data(), + out_dists.data(), + params.out_dists_ref_h.size(), + CompareApprox(1e-3))); + } + + protected: + void make_data() + { + std::vector indptr_h = params.indptr_h; + std::vector indices_h = params.indices_h; + std::vector data_h = params.data_h; + + auto stream = resource::get_cuda_stream(handle); + indptr.resize(indptr_h.size(), stream); + indices.resize(indices_h.size(), stream); + data.resize(data_h.size(), stream); + + update_device(indptr.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(indices.data(), indices_h.data(), indices_h.size(), stream); + update_device(data.data(), data_h.data(), data_h.size(), stream); + + std::vector out_dists_ref_h = params.out_dists_ref_h; + + out_dists_ref.resize((indptr_h.size() - 1) * (indptr_h.size() - 1), stream); + + update_device(out_dists_ref.data(), + out_dists_ref_h.data(), + out_dists_ref_h.size(), + resource::get_cuda_stream(handle)); + } + + raft::resources handle; + + // input data + rmm::device_uvector indptr, indices; + rmm::device_uvector data; + + // output data + rmm::device_uvector out_dists, out_dists_ref; + + SparseDistanceInputs params; +}; + +const std::vector> inputs_i32_f = { + {5, + {0, 0, 1, 2}, + + {1, 2}, + {0.5, 0.5}, + {0, 1, 1, 1, 0, 1, 1, 1, 0}, + cuvs::distance::DistanceType::CosineExpanded, + 0.0}, + {5, + {0, 0, 1, 2}, + + {1, 2}, + {1.0, 1.0}, + {0, 1, 1, 1, 0, 1, 1, 1, 0}, + cuvs::distance::DistanceType::JaccardExpanded, + 0.0}, + {2, + {0, 2, 4, 6, 8}, + {0, 1, 0, 1, 0, 1, 0, 1}, // indices + {1.0f, 3.0f, 1.0f, 5.0f, 50.0f, 28.0f, 16.0f, 2.0f}, + { + // dense output + 0.0, + 4.0, + 3026.0, + 226.0, + 4.0, + 0.0, + 2930.0, + 234.0, + 3026.0, + 2930.0, + 0.0, + 1832.0, + 226.0, + 234.0, + 1832.0, + 0.0, + }, + cuvs::distance::DistanceType::L2Expanded, + 0.0}, + {2, + {0, 2, 4, 6, 8}, + {0, 1, 0, 1, 0, 1, 0, 1}, + {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, + {5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0}, + cuvs::distance::DistanceType::InnerProduct, + 0.0}, + {2, + {0, 2, 4, 6, 8}, + {0, 1, 0, 1, 0, 1, 0, 1}, // indices + {1.0f, 3.0f, 1.0f, 5.0f, 50.0f, 28.0f, 16.0f, 2.0f}, + { + // dense output + 0.0, + 4.0, + 3026.0, + 226.0, + 4.0, + 0.0, + 2930.0, + 234.0, + 3026.0, + 2930.0, + 0.0, + 1832.0, + 226.0, + 234.0, + 1832.0, + 0.0, + }, + cuvs::distance::DistanceType::L2Unexpanded, + 0.0}, + + {10, + {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, + {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, 3, 4, 7, 0, 1, 2, 3, 4, + 6, 8, 0, 1, 2, 5, 7, 1, 5, 8, 9, 0, 1, 2, 5, 6, 8, 9, 2, 4, 5, 7, 0, 3, 9}, // indices + {0.5438, 0.2695, 0.4377, 0.7174, 0.9251, 0.7648, 0.3322, 0.7279, 0.4131, 0.5167, + 0.8655, 0.0730, 0.0291, 0.9036, 0.7988, 0.5019, 0.7663, 0.2190, 0.8206, 0.3625, + 0.0411, 0.3995, 0.5688, 0.7028, 0.8706, 0.3199, 0.4431, 0.0535, 0.2225, 0.8853, + 0.1932, 0.3761, 0.3379, 0.1771, 0.2107, 0.228, 0.5279, 0.4885, 0.3495, 0.5079, + 0.2325, 0.2331, 0.3018, 0.6231, 0.2645, 0.8429, 0.6625, 0.0797, 0.2724, 0.4218}, + {0., 0.39419924, 0.54823225, 0.79593037, 0.45658883, 0.93634219, 0.58146987, 0.44940102, + 1., 0.76978799, 0.39419924, 0., 0.97577154, 0.48904013, 0.48300801, 0.45087445, + 0.73323749, 0.21050481, 0.54847744, 0.78021386, 0.54823225, 0.97577154, 0., 0.51413997, + 0.31195441, 0.96546343, 0.67534399, 0.81665436, 0.8321819, 1., 0.79593037, 0.48904013, + 0.51413997, 0., 0.28605559, 0.35772784, 1., 0.60889396, 0.43324829, 0.84923694, + 0.45658883, 0.48300801, 0.31195441, 0.28605559, 0., 0.58623212, 0.6745457, 0.60287165, + 0.67676228, 0.73155632, 0.93634219, 0.45087445, 0.96546343, 0.35772784, 0.58623212, 0., + 0.77917274, 0.48390993, 0.24558392, 0.99166225, 0.58146987, 0.73323749, 0.67534399, 1., + 0.6745457, 0.77917274, 0., 0.27605686, 0.76064776, 0.61547536, 0.44940102, 0.21050481, + 0.81665436, 0.60889396, 0.60287165, 0.48390993, 0.27605686, 0., 0.51360432, 0.68185144, + 1., 0.54847744, 0.8321819, 0.43324829, 0.67676228, 0.24558392, 0.76064776, 0.51360432, + 0., 1., 0.76978799, 0.78021386, 1., 0.84923694, 0.73155632, 0.99166225, + 0.61547536, 0.68185144, 1., 0.}, + cuvs::distance::DistanceType::CosineExpanded, + 0.0}, + + {10, + {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, + {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, 3, 4, 7, 0, 1, 2, 3, 4, + 6, 8, 0, 1, 2, 5, 7, 1, 5, 8, 9, 0, 1, 2, 5, 6, 8, 9, 2, 4, 5, 7, 0, 3, 9}, // indices + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {0.0, + 0.42857142857142855, + 0.7142857142857143, + 0.75, + 0.2857142857142857, + 0.75, + 0.7142857142857143, + 0.5, + 1.0, + 0.6666666666666666, + 0.42857142857142855, + 0.0, + 0.75, + 0.625, + 0.375, + 0.42857142857142855, + 0.75, + 0.375, + 0.75, + 0.7142857142857143, + 0.7142857142857143, + 0.75, + 0.0, + 0.7142857142857143, + 0.42857142857142855, + 0.7142857142857143, + 0.6666666666666666, + 0.625, + 0.6666666666666666, + 1.0, + 0.75, + 0.625, + 0.7142857142857143, + 0.0, + 0.5, + 0.5714285714285714, + 1.0, + 0.8, + 0.5, + 0.6666666666666666, + 0.2857142857142857, + 0.375, + 0.42857142857142855, + 0.5, + 0.0, + 0.6666666666666666, + 0.7777777777777778, + 0.4444444444444444, + 0.7777777777777778, + 0.75, + 0.75, + 0.42857142857142855, + 0.7142857142857143, + 0.5714285714285714, + 0.6666666666666666, + 0.0, + 0.7142857142857143, + 0.5, + 0.5, + 0.8571428571428571, + 0.7142857142857143, + 0.75, + 0.6666666666666666, + 1.0, + 0.7777777777777778, + 0.7142857142857143, + 0.0, + 0.42857142857142855, + 0.8571428571428571, + 0.8333333333333334, + 0.5, + 0.375, + 0.625, + 0.8, + 0.4444444444444444, + 0.5, + 0.42857142857142855, + 0.0, + 0.7777777777777778, + 0.75, + 1.0, + 0.75, + 0.6666666666666666, + 0.5, + 0.7777777777777778, + 0.5, + 0.8571428571428571, + 0.7777777777777778, + 0.0, + 1.0, + 0.6666666666666666, + 0.7142857142857143, + 1.0, + 0.6666666666666666, + 0.75, + 0.8571428571428571, + 0.8333333333333334, + 0.75, + 1.0, + 0.0}, + cuvs::distance::DistanceType::JaccardExpanded, + 0.0}, + + {10, + {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, + {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, 3, 4, 7, 0, 1, 2, 3, 4, + 6, 8, 0, 1, 2, 5, 7, 1, 5, 8, 9, 0, 1, 2, 5, 6, 8, 9, 2, 4, 5, 7, 0, 3, 9}, // indices + {0.5438, 0.2695, 0.4377, 0.7174, 0.9251, 0.7648, 0.3322, 0.7279, 0.4131, 0.5167, + 0.8655, 0.0730, 0.0291, 0.9036, 0.7988, 0.5019, 0.7663, 0.2190, 0.8206, 0.3625, + 0.0411, 0.3995, 0.5688, 0.7028, 0.8706, 0.3199, 0.4431, 0.0535, 0.2225, 0.8853, + 0.1932, 0.3761, 0.3379, 0.1771, 0.2107, 0.228, 0.5279, 0.4885, 0.3495, 0.5079, + 0.2325, 0.2331, 0.3018, 0.6231, 0.2645, 0.8429, 0.6625, 0.0797, 0.2724, 0.4218}, + {0.0, + 3.3954660629919076, + 5.6469232737388815, + 6.373112846266441, + 4.0212880272531715, + 6.916281504639404, + 5.741508386786526, + 5.411470999663036, + 9.0, + 4.977014354725805, + 3.3954660629919076, + 0.0, + 7.56256082439209, + 5.540261147481582, + 4.832322929216881, + 4.62003193872216, + 6.498056792320361, + 4.309846252268695, + 6.317531174829905, + 6.016362684141827, + 5.6469232737388815, + 7.56256082439209, + 0.0, + 5.974878731322299, + 4.898357301336036, + 6.442097410320605, + 5.227077347287883, + 7.134101195584642, + 5.457753923371659, + 7.0, + 6.373112846266441, + 5.540261147481582, + 5.974878731322299, + 0.0, + 5.5507273748583, + 4.897749658726415, + 9.0, + 8.398776718824767, + 3.908281400328807, + 4.83431066343688, + 4.0212880272531715, + 4.832322929216881, + 4.898357301336036, + 5.5507273748583, + 0.0, + 6.632989819428174, + 7.438852294822894, + 5.6631570310967465, + 7.579428202635459, + 6.760811985364303, + 6.916281504639404, + 4.62003193872216, + 6.442097410320605, + 4.897749658726415, + 6.632989819428174, + 0.0, + 5.249404187382862, + 6.072559523278559, + 4.07661278488929, + 6.19678948003145, + 5.741508386786526, + 6.498056792320361, + 5.227077347287883, + 9.0, + 7.438852294822894, + 5.249404187382862, + 0.0, + 3.854811639654704, + 6.652724827169063, + 5.298236851430971, + 5.411470999663036, + 4.309846252268695, + 7.134101195584642, + 8.398776718824767, + 5.6631570310967465, + 6.072559523278559, + 3.854811639654704, + 0.0, + 7.529184598969917, + 6.903282911791188, + 9.0, + 6.317531174829905, + 5.457753923371659, + 3.908281400328807, + 7.579428202635459, + 4.07661278488929, + 6.652724827169063, + 7.529184598969917, + 0.0, + 7.0, + 4.977014354725805, + 6.016362684141827, + 7.0, + 4.83431066343688, + 6.760811985364303, + 6.19678948003145, + 5.298236851430971, + 6.903282911791188, + 7.0, + 0.0}, + cuvs::distance::DistanceType::Canberra, + 0.0}, + + {10, + {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, + {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, 3, 4, 7, 0, 1, 2, 3, 4, + 6, 8, 0, 1, 2, 5, 7, 1, 5, 8, 9, 0, 1, 2, 5, 6, 8, 9, 2, 4, 5, 7, 0, 3, 9}, // indices + {0.5438, 0.2695, 0.4377, 0.7174, 0.9251, 0.7648, 0.3322, 0.7279, 0.4131, 0.5167, + 0.8655, 0.0730, 0.0291, 0.9036, 0.7988, 0.5019, 0.7663, 0.2190, 0.8206, 0.3625, + 0.0411, 0.3995, 0.5688, 0.7028, 0.8706, 0.3199, 0.4431, 0.0535, 0.2225, 0.8853, + 0.1932, 0.3761, 0.3379, 0.1771, 0.2107, 0.228, 0.5279, 0.4885, 0.3495, 0.5079, + 0.2325, 0.2331, 0.3018, 0.6231, 0.2645, 0.8429, 0.6625, 0.0797, 0.2724, 0.4218}, + {0.0, + 1.31462855332296, + 1.3690307816129905, + 1.698603990921237, + 1.3460470789553531, + 1.6636670712582544, + 1.2651744044972217, + 1.1938329352055201, + 1.8811409082590185, + 1.3653115050624267, + 1.31462855332296, + 0.0, + 1.9447722703291133, + 1.42818777206562, + 1.4685491458946494, + 1.3071999866010466, + 1.4988622861692171, + 0.9698559287406783, + 1.4972023224597841, + 1.5243383567266802, + 1.3690307816129905, + 1.9447722703291133, + 0.0, + 1.2748400840107568, + 1.0599569946448246, + 1.546591282841402, + 1.147526531928459, + 1.447002179128145, + 1.5982242387673176, + 1.3112533607072414, + 1.698603990921237, + 1.42818777206562, + 1.2748400840107568, + 0.0, + 1.038121552545461, + 1.011788365364402, + 1.3907391109256988, + 1.3128200942311496, + 1.19595706584447, + 1.3233328139624725, + 1.3460470789553531, + 1.4685491458946494, + 1.0599569946448246, + 1.038121552545461, + 0.0, + 1.3642741698145529, + 1.3493868683808095, + 1.394942694628328, + 1.572881849642552, + 1.380122665319464, + 1.6636670712582544, + 1.3071999866010466, + 1.546591282841402, + 1.011788365364402, + 1.3642741698145529, + 0.0, + 1.018961640373018, + 1.0114394258945634, + 0.8338711034820684, + 1.1247823842299223, + 1.2651744044972217, + 1.4988622861692171, + 1.147526531928459, + 1.3907391109256988, + 1.3493868683808095, + 1.018961640373018, + 0.0, + 0.7701238110357329, + 1.245486437864406, + 0.5551259549534626, + 1.1938329352055201, + 0.9698559287406783, + 1.447002179128145, + 1.3128200942311496, + 1.394942694628328, + 1.0114394258945634, + 0.7701238110357329, + 0.0, + 1.1886800117391216, + 1.0083692448135637, + 1.8811409082590185, + 1.4972023224597841, + 1.5982242387673176, + 1.19595706584447, + 1.572881849642552, + 0.8338711034820684, + 1.245486437864406, + 1.1886800117391216, + 0.0, + 1.3661374102525012, + 1.3653115050624267, + 1.5243383567266802, + 1.3112533607072414, + 1.3233328139624725, + 1.380122665319464, + 1.1247823842299223, + 0.5551259549534626, + 1.0083692448135637, + 1.3661374102525012, + 0.0}, + cuvs::distance::DistanceType::LpUnexpanded, + 2.0}, + + {10, + {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, + {0, 1, 3, 6, 8, 0, 1, 2, 3, 5, 6, 1, 2, 4, 8, 0, 2, 3, 4, 7, 0, 1, 2, 3, 4, + 6, 8, 0, 1, 2, 5, 7, 1, 5, 8, 9, 0, 1, 2, 5, 6, 8, 9, 2, 4, 5, 7, 0, 3, 9}, // indices + {0.5438, 0.2695, 0.4377, 0.7174, 0.9251, 0.7648, 0.3322, 0.7279, 0.4131, 0.5167, + 0.8655, 0.0730, 0.0291, 0.9036, 0.7988, 0.5019, 0.7663, 0.2190, 0.8206, 0.3625, + 0.0411, 0.3995, 0.5688, 0.7028, 0.8706, 0.3199, 0.4431, 0.0535, 0.2225, 0.8853, + 0.1932, 0.3761, 0.3379, 0.1771, 0.2107, 0.228, 0.5279, 0.4885, 0.3495, 0.5079, + 0.2325, 0.2331, 0.3018, 0.6231, 0.2645, 0.8429, 0.6625, 0.0797, 0.2724, 0.4218}, + {0.0, + 0.9251771844789913, + 0.9036452083899731, + 0.9251771844789913, + 0.8706483735804971, + 0.9251771844789913, + 0.717493881903289, + 0.6920214832303888, + 0.9251771844789913, + 0.9251771844789913, + 0.9251771844789913, + 0.0, + 0.9036452083899731, + 0.8655339692155823, + 0.8706483735804971, + 0.8655339692155823, + 0.8655339692155823, + 0.6329837991017668, + 0.8655339692155823, + 0.8655339692155823, + 0.9036452083899731, + 0.9036452083899731, + 0.0, + 0.7988276152181608, + 0.7028075145996631, + 0.9036452083899731, + 0.9036452083899731, + 0.9036452083899731, + 0.8429599432532096, + 0.9036452083899731, + 0.9251771844789913, + 0.8655339692155823, + 0.7988276152181608, + 0.0, + 0.48376552205293305, + 0.8206394616536681, + 0.8206394616536681, + 0.8206394616536681, + 0.8429599432532096, + 0.8206394616536681, + 0.8706483735804971, + 0.8706483735804971, + 0.7028075145996631, + 0.48376552205293305, + 0.0, + 0.8706483735804971, + 0.8706483735804971, + 0.8706483735804971, + 0.8429599432532096, + 0.8706483735804971, + 0.9251771844789913, + 0.8655339692155823, + 0.9036452083899731, + 0.8206394616536681, + 0.8706483735804971, + 0.0, + 0.8853924473642432, + 0.535821510936138, + 0.6497196601457607, + 0.8853924473642432, + 0.717493881903289, + 0.8655339692155823, + 0.9036452083899731, + 0.8206394616536681, + 0.8706483735804971, + 0.8853924473642432, + 0.0, + 0.5279604218147174, + 0.6658348373853169, + 0.33799874888632914, + 0.6920214832303888, + 0.6329837991017668, + 0.9036452083899731, + 0.8206394616536681, + 0.8706483735804971, + 0.535821510936138, + 0.5279604218147174, + 0.0, + 0.662579808115858, + 0.5079750812968089, + 0.9251771844789913, + 0.8655339692155823, + 0.8429599432532096, + 0.8429599432532096, + 0.8429599432532096, + 0.6497196601457607, + 0.6658348373853169, + 0.662579808115858, + 0.0, + 0.8429599432532096, + 0.9251771844789913, + 0.8655339692155823, + 0.9036452083899731, + 0.8206394616536681, + 0.8706483735804971, + 0.8853924473642432, + 0.33799874888632914, + 0.5079750812968089, + 0.8429599432532096, + 0.0}, + cuvs::distance::DistanceType::Linf, + 0.0}, + + {15, + {0, 5, 8, 9, 15, 20, 26, 31, 34, 38, 45}, + {0, 1, 5, 6, 9, 1, 4, 14, 7, 3, 4, 7, 9, 11, 14, 0, 3, 7, 8, 12, 0, 2, 5, + 7, 8, 14, 4, 9, 10, 11, 13, 4, 10, 14, 5, 6, 8, 9, 0, 2, 3, 4, 6, 10, 11}, + {0.13537497, 0.51440163, 0.17231936, 0.02417618, 0.15372786, 0.17760507, 0.73789274, 0.08450219, + 1., 0.20184723, 0.18036963, 0.12581403, 0.13867603, 0.24040536, 0.11288773, 0.00290246, + 0.09120187, 0.31190555, 0.43245423, 0.16153588, 0.3233026, 0.05279589, 0.1387149, 0.05962761, + 0.41751856, 0.00804045, 0.03262381, 0.27507131, 0.37245804, 0.16378881, 0.15605804, 0.3867739, + 0.24908977, 0.36413632, 0.37643732, 0.28910679, 0.0198409, 0.31461499, 0.24412279, 0.08327667, + 0.04444576, 0.05047969, 0.26190054, 0.2077349, 0.10803964}, + {1.05367121e-08, 8.35309089e-01, 1.00000000e+00, 9.24116813e-01, + 9.90039274e-01, 7.97613546e-01, 8.91271059e-01, 1.00000000e+00, + 6.64669302e-01, 8.59439512e-01, 8.35309089e-01, 1.05367121e-08, + 1.00000000e+00, 7.33151506e-01, 1.00000000e+00, 9.86880955e-01, + 9.19154851e-01, 5.38849774e-01, 1.00000000e+00, 8.98332369e-01, + 1.00000000e+00, 1.00000000e+00, 0.00000000e+00, 8.03303970e-01, + 6.64465915e-01, 8.69374690e-01, 1.00000000e+00, 1.00000000e+00, + 1.00000000e+00, 1.00000000e+00, 9.24116813e-01, 7.33151506e-01, + 8.03303970e-01, 0.00000000e+00, 8.16225843e-01, 9.39818306e-01, + 7.27700415e-01, 7.30155528e-01, 8.89451011e-01, 8.05419635e-01, + 9.90039274e-01, 1.00000000e+00, 6.64465915e-01, 8.16225843e-01, + 0.00000000e+00, 6.38804490e-01, 1.00000000e+00, 1.00000000e+00, + 9.52559809e-01, 9.53789212e-01, 7.97613546e-01, 9.86880955e-01, + 8.69374690e-01, 9.39818306e-01, 6.38804490e-01, 0.0, + 1.00000000e+00, 9.72569112e-01, 8.24907516e-01, 8.07933016e-01, + 8.91271059e-01, 9.19154851e-01, 1.00000000e+00, 7.27700415e-01, + 1.00000000e+00, 1.00000000e+00, 0.00000000e+00, 7.63596268e-01, + 8.40131263e-01, 7.40428532e-01, 1.00000000e+00, 5.38849774e-01, + 1.00000000e+00, 7.30155528e-01, 1.00000000e+00, 9.72569112e-01, + 7.63596268e-01, 0.00000000e+00, 1.00000000e+00, 7.95485011e-01, + 6.64669302e-01, 1.00000000e+00, 1.00000000e+00, 8.89451011e-01, + 9.52559809e-01, 8.24907516e-01, 8.40131263e-01, 1.00000000e+00, + 0.00000000e+00, 8.51370877e-01, 8.59439512e-01, 8.98332369e-01, + 1.00000000e+00, 8.05419635e-01, 9.53789212e-01, 8.07933016e-01, + 7.40428532e-01, 7.95485011e-01, 8.51370877e-01, 1.49011612e-08}, + // Dataset is L1 normalized into pdfs + cuvs::distance::DistanceType::HellingerExpanded, + 0.0}, + + {4, + {0, 1, 1, 2, 4}, + {3, 2, 0, 1}, // indices + {0.99296, 0.42180, 0.11687, 0.305869}, + { + // dense output + 0.0, + 0.99296, + 1.41476, + 1.415707, + 0.99296, + 0.0, + 0.42180, + 0.42274, + 1.41476, + 0.42180, + 0.0, + 0.84454, + 1.41570, + 0.42274, + 0.84454, + 0.0, + }, + cuvs::distance::DistanceType::L1, + 0.0}, + {5, + {0, 3, 8, 12, 16, 20, 25, 30, 35, 40, 45}, + {0, 3, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, + 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4}, + {0.70862347, 0.8232774, 0.12108795, 0.84527547, 0.94937088, 0.03258545, 0.99584118, 0.76835667, + 0.34426657, 0.2357925, 0.01274851, 0.11422017, 0.3437756, 0.31967718, 0.5956055, 0.31610373, + 0.04147273, 0.03724415, 0.21515727, 0.04751052, 0.50283183, 0.99957274, 0.01395933, 0.96032529, + 0.88438711, 0.46095378, 0.27432481, 0.54294211, 0.54280225, 0.59503329, 0.61364678, 0.22837736, + 0.56609561, 0.29809423, 0.76736686, 0.56460608, 0.98165371, 0.02140123, 0.19881268, 0.26057815, + 0.31648823, 0.89874295, 0.27366735, 0.5119944, 0.11416134}, + {// dense output + 0., 0.48769777, 1.88014197, 0.26127048, 0.26657011, 0.7874794, 0.76962708, 1.122858, + 1.1232498, 1.08166081, 0.48769777, 0., 1.31332116, 0.98318907, 0.42661815, 0.09279052, + 1.35187836, 1.38429055, 0.40658897, 0.56136388, 1.88014197, 1.31332116, 0., 1.82943642, + 1.54826077, 1.05918884, 1.59360067, 1.34698954, 0.60215168, 0.46993848, 0.26127048, 0.98318907, + 1.82943642, 0., 0.29945563, 1.08494093, 0.22934281, 0.82801925, 1.74288748, 1.50610116, + 0.26657011, 0.42661815, 1.54826077, 0.29945563, 0., 0.45060069, 0.77814948, 1.45245711, + 1.18328348, 0.82486987, 0.7874794, 0.09279052, 1.05918884, 1.08494093, 0.45060069, 0., + 1.29899154, 1.40683824, 0.48505269, 0.53862363, 0.76962708, 1.35187836, 1.59360067, 0.22934281, + 0.77814948, 1.29899154, 0., 0.33202426, 1.92108999, 1.88812175, 1.122858, 1.38429055, + 1.34698954, 0.82801925, 1.45245711, 1.40683824, 0.33202426, 0., 1.47318624, 1.92660889, + 1.1232498, 0.40658897, 0.60215168, 1.74288748, 1.18328348, 0.48505269, 1.92108999, 1.47318624, + 0., 0.24992619, 1.08166081, 0.56136388, 0.46993848, 1.50610116, 0.82486987, 0.53862363, + 1.88812175, 1.92660889, 0.24992619, 0.}, + cuvs::distance::DistanceType::CorrelationExpanded, + 0.0}, + {5, + {0, 1, 2, 4, 4, 5, 6, 7, 9, 9, 10}, + {1, 4, 0, 4, 1, 3, 0, 1, 3, 0}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {// dense output + 0., 1., 1., 1., 0.8, 1., 1., 0.8, 1., 1., 1., 0., 0.8, 1., 1., 1., 1., 1., 1., 1., + 1., 0.8, 0., 1., 1., 1., 0.8, 1., 1., 0.8, 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., + 0.8, 1., 1., 1., 0., 1., 1., 0.8, 1., 1., 1., 1., 1., 1., 1., 0., 1., 0.8, 1., 1., + 1., 1., 0.8, 1., 1., 1., 0., 1., 1., 0.8, 0.8, 1., 1., 1., 0.8, 0.8, 1., 0., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0.8, 1., 1., 1., 0.8, 1., 1., 0.}, + cuvs::distance::DistanceType::RusselRaoExpanded, + 0.0}, + {5, + {0, 1, 1, 3, 3, 4, 4, 6, 9, 10, 10}, + {0, 3, 4, 4, 2, 3, 0, 2, 3, 2}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {// dense output + 0., 0.2, 0.6, 0.2, 0.4, 0.2, 0.6, 0.4, 0.4, 0.2, 0.2, 0., 0.4, 0., 0.2, 0., 0.4, + 0.6, 0.2, 0., 0.6, 0.4, 0., 0.4, 0.2, 0.4, 0.4, 0.6, 0.6, 0.4, 0.2, 0., 0.4, 0., + 0.2, 0., 0.4, 0.6, 0.2, 0., 0.4, 0.2, 0.2, 0.2, 0., 0.2, 0.6, 0.8, 0.4, 0.2, 0.2, + 0., 0.4, 0., 0.2, 0., 0.4, 0.6, 0.2, 0., 0.6, 0.4, 0.4, 0.4, 0.6, 0.4, 0., 0.2, + 0.2, 0.4, 0.4, 0.6, 0.6, 0.6, 0.8, 0.6, 0.2, 0., 0.4, 0.6, 0.4, 0.2, 0.6, 0.2, 0.4, + 0.2, 0.2, 0.4, 0., 0.2, 0.2, 0., 0.4, 0., 0.2, 0., 0.4, 0.6, 0.2, 0.}, + cuvs::distance::DistanceType::HammingUnexpanded, + 0.0}, + {3, + {0, 1, 2}, + {0, 1}, + {1.0, 1.0}, + {0.0, 0.83255, 0.83255, 0.0}, + cuvs::distance::DistanceType::JensenShannon, + 0.0}, + {2, + {0, 1, 3}, + {0, 0, 1}, + {1.0, 0.5, 0.5}, + {0, 0.4645014, 0.4645014, 0}, + cuvs::distance::DistanceType::JensenShannon, + 0.0}, + {3, + {0, 1, 2}, + {0, 0}, + {1.0, 1.0}, + {0.0, 0.0, 0.0, 0.0}, + cuvs::distance::DistanceType::JensenShannon, + 0.0}, + + {3, + {0, 1, 2}, + {0, 1}, + {1.0, 1.0}, + {0.0, 1.0, 1.0, 0.0}, + cuvs::distance::DistanceType::DiceExpanded, + 0.0}, + {3, + {0, 1, 3}, + {0, 0, 1}, + {1.0, 1.0, 1.0}, + {0, 0.333333, 0.333333, 0}, + cuvs::distance::DistanceType::DiceExpanded, + 0.0}, + +}; + +typedef SparseDistanceTest SparseDistanceTestF; +TEST_P(SparseDistanceTestF, Result) { compare(); } +INSTANTIATE_TEST_CASE_P(SparseDistanceTests, + SparseDistanceTestF, + ::testing::ValuesIn(inputs_i32_f)); + +} // end namespace distance +} // end namespace cuvs diff --git a/cpp/test/neighbors/sparse_brute_force.cu b/cpp/test/neighbors/sparse_brute_force.cu new file mode 100644 index 000000000..cb68989d4 --- /dev/null +++ b/cpp/test/neighbors/sparse_brute_force.cu @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include + +#include +#include + +namespace cuvs { +namespace neighbors { + +using namespace raft; +using namespace raft::sparse; + +template +struct SparseKNNInputs { + value_idx n_cols; + + std::vector indptr_h; + std::vector indices_h; + std::vector data_h; + + std::vector out_dists_ref_h; + std::vector out_indices_ref_h; + + int k; + + int batch_size_index = 2; + int batch_size_query = 2; + + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const SparseKNNInputs& dims) +{ + return os; +} + +template +class SparseKNNTest : public ::testing::TestWithParam> { + public: + SparseKNNTest() + : params(::testing::TestWithParam>::GetParam()), + indptr(0, resource::get_cuda_stream(handle)), + indices(0, resource::get_cuda_stream(handle)), + data(0, resource::get_cuda_stream(handle)), + out_indices(0, resource::get_cuda_stream(handle)), + out_dists(0, resource::get_cuda_stream(handle)), + out_indices_ref(0, resource::get_cuda_stream(handle)), + out_dists_ref(0, resource::get_cuda_stream(handle)) + { + } + + protected: + void SetUp() override + { + n_rows = params.indptr_h.size() - 1; + nnz = params.indices_h.size(); + k = params.k; + + make_data(); + + auto index_structure = + raft::make_device_compressed_structure_view( + indptr.data(), indices.data(), n_rows, params.n_cols, nnz); + auto index_csr = raft::make_device_csr_matrix_view(data.data(), index_structure); + + auto index = cuvs::neighbors::brute_force::build(handle, index_csr, params.metric); + + cuvs::neighbors::brute_force::sparse_search_params search_params; + search_params.batch_size_index = params.batch_size_index; + search_params.batch_size_query = params.batch_size_query; + + cuvs::neighbors::brute_force::search( + handle, + search_params, + index, + index_csr, + raft::make_device_matrix_view(out_indices.data(), n_rows, k), + raft::make_device_matrix_view(out_dists.data(), n_rows, k)); + + RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle))); + } + + void compare() + { + ASSERT_TRUE(devArrMatch( + out_dists_ref.data(), out_dists.data(), n_rows * k, CompareApprox(1e-4))); + ASSERT_TRUE( + devArrMatch(out_indices_ref.data(), out_indices.data(), n_rows * k, Compare())); + } + + protected: + void make_data() + { + std::vector indptr_h = params.indptr_h; + std::vector indices_h = params.indices_h; + std::vector data_h = params.data_h; + + auto stream = resource::get_cuda_stream(handle); + indptr.resize(indptr_h.size(), stream); + indices.resize(indices_h.size(), stream); + data.resize(data_h.size(), stream); + + update_device(indptr.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(indices.data(), indices_h.data(), indices_h.size(), stream); + update_device(data.data(), data_h.data(), data_h.size(), stream); + + std::vector out_dists_ref_h = params.out_dists_ref_h; + std::vector out_indices_ref_h = params.out_indices_ref_h; + + out_indices_ref.resize(out_indices_ref_h.size(), stream); + out_dists_ref.resize(out_dists_ref_h.size(), stream); + + update_device( + out_indices_ref.data(), out_indices_ref_h.data(), out_indices_ref_h.size(), stream); + update_device(out_dists_ref.data(), out_dists_ref_h.data(), out_dists_ref_h.size(), stream); + + out_dists.resize(n_rows * k, stream); + out_indices.resize(n_rows * k, stream); + } + + raft::resources handle; + + int n_rows, nnz, k; + + // input data + rmm::device_uvector indptr, indices; + rmm::device_uvector data; + + // output data + rmm::device_uvector out_indices; + rmm::device_uvector out_dists; + + rmm::device_uvector out_indices_ref; + rmm::device_uvector out_dists_ref; + + SparseKNNInputs params; +}; + +const std::vector> inputs_i32_f = { + {9, // ncols + {0, 2, 4, 6, 8}, // indptr + {0, 4, 0, 3, 0, 2, 0, 8}, // indices + {0.0f, 1.0f, 5.0f, 6.0f, 5.0f, 6.0f, 0.0f, 1.0f}, // data + {0, 1.41421, 0, 7.87401, 0, 7.87401, 0, 1.41421}, // dists + {0, 3, 1, 0, 2, 0, 3, 0}, // inds + 2, + 2, + 2, + cuvs::distance::DistanceType::L2SqrtExpanded}}; +typedef SparseKNNTest SparseKNNTestF; +TEST_P(SparseKNNTestF, Result) { compare(); } +INSTANTIATE_TEST_CASE_P(SparseKNNTest, SparseKNNTestF, ::testing::ValuesIn(inputs_i32_f)); + +}; // end namespace neighbors +}; // end namespace cuvs