Skip to content

Commit

Permalink
Migrate sparse knn and distances code from raft (#457)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #457
  • Loading branch information
benfred authored Nov 20, 2024
1 parent f127b06 commit 06afd5b
Show file tree
Hide file tree
Showing 22 changed files with 4,548 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ if(BUILD_SHARED_LIBS)
src/distance/detail/fused_distance_nn.cu
src/distance/distance.cu
src/distance/pairwise_distance.cu
src/distance/sparse_distance.cu
src/neighbors/brute_force.cu
src/neighbors/cagra_build_float.cu
src/neighbors/cagra_build_half.cu
Expand Down Expand Up @@ -449,6 +450,7 @@ if(BUILD_SHARED_LIBS)
src/neighbors/refine/detail/refine_host_int8_t_float.cpp
src/neighbors/refine/detail/refine_host_uint8_t_float.cpp
src/neighbors/sample_filter.cu
src/neighbors/sparse_brute_force.cu
src/neighbors/vamana_build_float.cu
src/neighbors/vamana_build_uint8.cu
src/neighbors/vamana_build_int8.cu
Expand Down
81 changes: 81 additions & 0 deletions cpp/include/cuvs/distance/distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <cstdint>
#include <cuda_fp16.h>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

Expand Down Expand Up @@ -331,6 +332,86 @@ void pairwise_distance(
cuvs::distance::DistanceType metric,
float metric_arg = 2.0f);

/**
* @brief Compute sparse pairwise distances between x and y, using the provided
* input configuration and distance function.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
* #include <raft/core/device_csr_matrix.hpp>
* #include <raft/core/device_mdspan.hpp>
*
* int x_n_rows = 100000;
* int y_n_rows = 50000;
* int n_cols = 10000;
*
* raft::device_resources handle;
* auto x = raft::make_device_csr_matrix<float>(handle, x_n_rows, n_cols);
* auto y = raft::make_device_csr_matrix<float>(handle, y_n_rows, n_cols);
*
* ...
* // populate data
* ...
*
* auto out = raft::make_device_matrix<float>(handle, x_nrows, y_nrows);
* auto metric = cuvs::distance::DistanceType::L2Expanded;
* raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric);
* @endcode
*
* @param[in] handle raft::resources
* @param[in] x raft::device_csr_matrix_view
* @param[in] y raft::device_csr_matrix_view
* @param[out] dist raft::device_matrix_view dense matrix
* @param[in] metric distance metric to use
* @param[in] metric_arg metric argument (used for Minkowski distance)
*/
void pairwise_distance(raft::resources const& handle,
raft::device_csr_matrix_view<const float, int, int, int> x,
raft::device_csr_matrix_view<const float, int, int, int> y,
raft::device_matrix_view<float, int, raft::row_major> dist,
cuvs::distance::DistanceType metric,
float metric_arg = 2.0f);

/**
* @brief Compute sparse pairwise distances between x and y, using the provided
* input configuration and distance function.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
* #include <raft/core/device_csr_matrix.hpp>
* #include <raft/core/device_mdspan.hpp>
*
* int x_n_rows = 100000;
* int y_n_rows = 50000;
* int n_cols = 10000;
*
* raft::device_resources handle;
* auto x = raft::make_device_csr_matrix<double>(handle, x_n_rows, n_cols);
* auto y = raft::make_device_csr_matrix<double>(handle, y_n_rows, n_cols);
*
* ...
* // populate data
* ...
*
* auto out = raft::make_device_matrix<double>(handle, x_nrows, y_nrows);
* auto metric = cuvs::distance::DistanceType::L2Expanded;
* raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric);
* @endcode
*
* @param[in] handle raft::resources
* @param[in] x raft::device_csr_matrix_view
* @param[in] y raft::device_csr_matrix_view
* @param[out] dist raft::device_matrix_view dense matrix
* @param[in] metric distance metric to use
* @param[in] metric_arg metric argument (used for Minkowski distance)
*/
void pairwise_distance(raft::resources const& handle,
raft::device_csr_matrix_view<const double, int, int, int> x,
raft::device_csr_matrix_view<const double, int, int, int> y,
raft::device_matrix_view<double, int, raft::row_major> dist,
cuvs::distance::DistanceType metric,
float metric_arg = 2.0f);

/** @} */ // end group pairwise_distance_runtime

}; // namespace cuvs::distance
104 changes: 104 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "common.hpp"
#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
Expand Down Expand Up @@ -375,4 +376,107 @@ void search(raft::resources const& handle,
* @}
*/

/**
* @defgroup sparse_bruteforce_cpp_index Sparse Brute Force index
* @{
*/
/**
* @brief Sparse Brute Force index.
*
* @tparam T Data element type
* @tparam IdxT Index element type
*/
template <typename T, typename IdxT>
struct sparse_index {
public:
sparse_index(const sparse_index&) = delete;
sparse_index(sparse_index&&) = default;
sparse_index& operator=(const sparse_index&) = delete;
sparse_index& operator=(sparse_index&&) = default;
~sparse_index() = default;

/** Construct a sparse brute force sparse_index from dataset */
sparse_index(raft::resources const& res,
raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> dataset,
cuvs::distance::DistanceType metric,
T metric_arg);

/** Distance metric used for retrieval */
cuvs::distance::DistanceType metric() const noexcept { return metric_; }

/** Metric argument */
T metric_arg() const noexcept { return metric_arg_; }

raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> dataset() const noexcept
{
return dataset_;
}

private:
raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> dataset_;
cuvs::distance::DistanceType metric_;
T metric_arg_;
};
/**
* @}
*/

/**
* @defgroup sparse_bruteforce_cpp_index_build Sparse Brute Force index build
* @{
*/

/*
* @brief Build the Sparse index from the dataset
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a CSR dataset
* auto index = brute_force::build(handle, dataset, metric);
* @endcode
*
* @param[in] handle
* @param[in] dataset A sparse CSR matrix in device memory to search against
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed Sparse brute-force index
*/
auto build(raft::resources const& handle,
raft::device_csr_matrix_view<const float, int, int, int> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::sparse_index<float, int>;
/**
* @}
*/

/**
* @defgroup sparse_bruteforce_cpp_index_search Sparse Brute Force index search
* @{
*/
struct sparse_search_params {
int batch_size_index = 2 << 14;
int batch_size_query = 2 << 14;
};

/*
* @brief Search the sparse bruteforce index for nearest neighbors
*
* @param[in] handle
* @param[in] index Sparse brute-force constructed index
* @param[in] queries a sparse CSR matrix on the device to query
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
*/
void search(raft::resources const& handle,
const sparse_search_params& params,
const sparse_index<float, int>& index,
raft::device_csr_matrix_view<const float, int, int, int> dataset,
raft::device_matrix_view<int, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);
/**
* @}
*/
} // namespace cuvs::neighbors::brute_force
Loading

0 comments on commit 06afd5b

Please sign in to comment.