Skip to content

Commit

Permalink
Merge branch 'branch-24.06' into update-fmt-and-spdlog
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored May 3, 2024
2 parents b0f375d + 3406569 commit 2ca5cbe
Show file tree
Hide file tree
Showing 27 changed files with 387 additions and 129 deletions.
2 changes: 2 additions & 0 deletions conda/recipes/libraft/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ outputs:
{% if cuda_major != "11" %}
- cuda-cudart-dev
{% endif %}
- librmm ={{ minor_version }}
run:
- {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }}
{% if cuda_major == "11" %}
Expand Down Expand Up @@ -93,6 +94,7 @@ outputs:
requirements:
host:
- cuda-version ={{ cuda_version }}
- librmm ={{ minor_version }}
run:
- {{ pin_subpackage('libraft-headers-only', exact=True) }}
- librmm ={{ minor_version }}
Expand Down
8 changes: 5 additions & 3 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>

Expand All @@ -48,13 +49,14 @@ namespace raft::neighbors::cagra {
*
* The following distance metrics are supported:
* - L2Expanded
* - InnerProduct
*
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* ivf_pq::index_params build_params;
* ivf_pq::search_params search_params
* // use default index parameters based on shape of the dataset
* ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset);
* ivf_pq::search_params search_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
Expand Down
32 changes: 16 additions & 16 deletions cpp/include/raft/neighbors/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ namespace raft::neighbors::cagra {
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = raft::cagra::build(...);`
* raft::cagra::serialize(handle, os, index);
* // create an index with `auto index = raft::neighbors::cagra::build(...);`
* raft::neighbors::cagra::serialize(handle, os, index);
* @endcode
*
* @tparam T data element type
Expand Down Expand Up @@ -67,14 +67,14 @@ void serialize(raft::resources const& handle,
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = raft::cagra::build(...);`
* raft::cagra::serialize(handle, filename, index);
* // create an index with `auto index = raft::neighbors::cagra::build(...);`
* raft::neighbors::cagra::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
Expand Down Expand Up @@ -102,14 +102,14 @@ void serialize(raft::resources const& handle,
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = raft::cagra::build(...);`
* raft::cagra::serialize_to_hnswlib(handle, os, index);
* // create an index with `auto index = raft::neighbors::cagra::build(...);`
* raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index);
* @endcode
*
* @tparam T data element type
Expand All @@ -135,14 +135,14 @@ void serialize_to_hnswlib(raft::resources const& handle,
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = raft::cagra::build(...);`
* raft::cagra::serialize_to_hnswlib(handle, filename, index);
* // create an index with `auto index = raft::neighbors::cagra::build(...);`
* raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index);
* @endcode
*
* @tparam T data element type
Expand All @@ -168,15 +168,15 @@ void serialize_to_hnswlib(raft::resources const& handle,
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = float; // data element type
* using IdxT = int; // type of the index
* auto index = raft::cagra::deserialize<T, IdxT>(handle, is);
* auto index = raft::neighbors::cagra::deserialize<T, IdxT>(handle, is);
* @endcode
*
* @tparam T data element type
Expand All @@ -200,15 +200,15 @@ index<T, IdxT> deserialize(raft::resources const& handle, std::istream& is)
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = float; // data element type
* using IdxT = int; // type of the index
* auto index = raft::cagra::deserialize<T, IdxT>(handle, filename);
* auto index = raft::neighbors::cagra::deserialize<T, IdxT>(handle, filename);
* @endcode
*
* @tparam T data element type
Expand Down
22 changes: 10 additions & 12 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
Expand Down Expand Up @@ -50,24 +51,17 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded ||
build_params->metric == distance::DistanceType::InnerProduct,
"Currently only L2Expanded or InnerProduct metric are supported");

uint32_t node_degree = knn_graph.extent(1);
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::build_graph(%zu, %zu, %u)",
size_t(dataset.extent(0)),
size_t(dataset.extent(1)),
node_degree);

if (!build_params) {
build_params = ivf_pq::index_params{};
build_params->n_lists = dataset.extent(0) < 4 * 2500 ? 4 : (uint32_t)(dataset.extent(0) / 2500);
build_params->pq_dim = raft::Pow2<8>::roundUp(dataset.extent(1) / 2);
build_params->pq_bits = 8;
build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10;
build_params->kmeans_n_iters = 25;
build_params->add_data_on_build = true;
}
if (!build_params) { build_params = ivf_pq::index_params::from_dataset(dataset); }

// Make model name
const std::string model_name = [&]() {
Expand Down Expand Up @@ -324,8 +318,10 @@ index<T, IdxT> build(

if (params.build_algo == graph_build_algo::IVF_PQ) {
build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params);

} else {
RAFT_EXPECTS(
params.metric == raft::distance::DistanceType::L2Expanded,
"L2Expanded is the only distance metrics supported for CAGRA build with nn_descent");
// Use nn-descent to build CAGRA knn graph
if (!nn_descent_params) {
nn_descent_params = experimental::nn_descent::index_params();
Expand All @@ -348,6 +344,8 @@ index<T, IdxT> build(
// Construct an index from dataset and optimized knn graph.
if (construct_index_with_dataset) {
if (params.compression.has_value()) {
RAFT_EXPECTS(params.metric == raft::distance::DistanceType::L2Expanded,
"VPQ compression is only supported with L2Expanded distance mertric");
index<T, IdxT> idx(res, params.metric);
idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view()));
idx.update_dataset(
Expand Down
35 changes: 26 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/detail/ivf_common.cuh>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
Expand Down Expand Up @@ -87,7 +88,8 @@ void search_main_core(
raft::device_matrix_view<const typename DatasetDescriptorT::DATA_T, int64_t, row_major> queries,
raft::device_matrix_view<typename DatasetDescriptorT::INDEX_T, int64_t, row_major> neighbors,
raft::device_matrix_view<typename DatasetDescriptorT::DISTANCE_T, int64_t, row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
CagraSampleFilterT sample_filter = CagraSampleFilterT(),
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
{
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(dataset_desc.size),
Expand All @@ -112,7 +114,7 @@ void search_main_core(
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DatasetDescriptorT, CagraSampleFilterT_s>> plan =
factory<DatasetDescriptorT, CagraSampleFilterT_s>::create(
res, params, dataset_desc.dim, graph.extent(1), topk);
res, params, dataset_desc.dim, graph.extent(1), topk, metric);

plan->check(topk);

Expand Down Expand Up @@ -163,7 +165,8 @@ void launch_vpq_search_main_core(
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<InternalIdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter)
CagraSampleFilterT sample_filter,
const raft::distance::DistanceType metric)
{
RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4,
Expand Down Expand Up @@ -192,7 +195,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else if (vpq_dset->pq_len() == 4) {
using dataset_desc_t = cagra_q_dataset_descriptor_t<T,
DatasetT,
Expand All @@ -210,7 +213,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else {
RAFT_FAIL("Subspace dimension must be 2 or 4");
}
Expand Down Expand Up @@ -268,17 +271,31 @@ void search_main(raft::resources const& res,
strided_dset->n_rows(),
strided_dset->dim(),
strided_dset->stride());

search_main_core<dataset_desc_t, CagraSampleFilterT>(
res, params, dataset_desc, graph_internal, queries, neighbors, distances, sample_filter);
search_main_core<dataset_desc_t, CagraSampleFilterT>(res,
params,
dataset_desc,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<float, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
// Search using a compressed dataset
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<half, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
launch_vpq_search_main_core<T, half, ds_idx_type, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, vpq_dset, params, graph_internal, queries, neighbors, distances, sample_filter);
res,
vpq_dset,
params,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* empty_dset = dynamic_cast<const empty_dataset<ds_idx_type>*>(&index.data());
empty_dset != nullptr) {
// Forgot to add a dataset.
Expand Down
Loading

0 comments on commit 2ca5cbe

Please sign in to comment.