Skip to content

Commit

Permalink
InnerProduct Distance Metric for CAGRA search (#2260)
Browse files Browse the repository at this point in the history
`InnerProduct` Distance Metric for CAGRA search. InnerProduct in graph building is supported using IVF-PQ for building the graph. NNDescent does not currently support any other metric except L2Expanded.

Authors:
  - Tarang Jain (https://github.com/tarang-jain)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - tsuki (https://github.com/enp1s0)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2260
  • Loading branch information
tarang-jain authored Apr 30, 2024
1 parent d4d92ce commit e720de7
Show file tree
Hide file tree
Showing 21 changed files with 336 additions and 111 deletions.
8 changes: 5 additions & 3 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>

Expand All @@ -48,13 +49,14 @@ namespace raft::neighbors::cagra {
*
* The following distance metrics are supported:
* - L2Expanded
* - InnerProduct
*
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* ivf_pq::index_params build_params;
* ivf_pq::search_params search_params
* // use default index parameters based on shape of the dataset
* ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset);
* ivf_pq::search_params search_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
Expand Down
22 changes: 10 additions & 12 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
Expand Down Expand Up @@ -50,24 +51,17 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded ||
build_params->metric == distance::DistanceType::InnerProduct,
"Currently only L2Expanded or InnerProduct metric are supported");

uint32_t node_degree = knn_graph.extent(1);
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::build_graph(%zu, %zu, %u)",
size_t(dataset.extent(0)),
size_t(dataset.extent(1)),
node_degree);

if (!build_params) {
build_params = ivf_pq::index_params{};
build_params->n_lists = dataset.extent(0) < 4 * 2500 ? 4 : (uint32_t)(dataset.extent(0) / 2500);
build_params->pq_dim = raft::Pow2<8>::roundUp(dataset.extent(1) / 2);
build_params->pq_bits = 8;
build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10;
build_params->kmeans_n_iters = 25;
build_params->add_data_on_build = true;
}
if (!build_params) { build_params = ivf_pq::index_params::from_dataset(dataset); }

// Make model name
const std::string model_name = [&]() {
Expand Down Expand Up @@ -324,8 +318,10 @@ index<T, IdxT> build(

if (params.build_algo == graph_build_algo::IVF_PQ) {
build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params);

} else {
RAFT_EXPECTS(
params.metric == raft::distance::DistanceType::L2Expanded,
"L2Expanded is the only distance metrics supported for CAGRA build with nn_descent");
// Use nn-descent to build CAGRA knn graph
if (!nn_descent_params) {
nn_descent_params = experimental::nn_descent::index_params();
Expand All @@ -348,6 +344,8 @@ index<T, IdxT> build(
// Construct an index from dataset and optimized knn graph.
if (construct_index_with_dataset) {
if (params.compression.has_value()) {
RAFT_EXPECTS(params.metric == raft::distance::DistanceType::L2Expanded,
"VPQ compression is only supported with L2Expanded distance mertric");
index<T, IdxT> idx(res, params.metric);
idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view()));
idx.update_dataset(
Expand Down
35 changes: 26 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/detail/ivf_common.cuh>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
Expand Down Expand Up @@ -87,7 +88,8 @@ void search_main_core(
raft::device_matrix_view<const typename DatasetDescriptorT::DATA_T, int64_t, row_major> queries,
raft::device_matrix_view<typename DatasetDescriptorT::INDEX_T, int64_t, row_major> neighbors,
raft::device_matrix_view<typename DatasetDescriptorT::DISTANCE_T, int64_t, row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
CagraSampleFilterT sample_filter = CagraSampleFilterT(),
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
{
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(dataset_desc.size),
Expand All @@ -112,7 +114,7 @@ void search_main_core(
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DatasetDescriptorT, CagraSampleFilterT_s>> plan =
factory<DatasetDescriptorT, CagraSampleFilterT_s>::create(
res, params, dataset_desc.dim, graph.extent(1), topk);
res, params, dataset_desc.dim, graph.extent(1), topk, metric);

plan->check(topk);

Expand Down Expand Up @@ -163,7 +165,8 @@ void launch_vpq_search_main_core(
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<InternalIdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter)
CagraSampleFilterT sample_filter,
const raft::distance::DistanceType metric)
{
RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4,
Expand Down Expand Up @@ -192,7 +195,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else if (vpq_dset->pq_len() == 4) {
using dataset_desc_t = cagra_q_dataset_descriptor_t<T,
DatasetT,
Expand All @@ -210,7 +213,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else {
RAFT_FAIL("Subspace dimension must be 2 or 4");
}
Expand Down Expand Up @@ -268,17 +271,31 @@ void search_main(raft::resources const& res,
strided_dset->n_rows(),
strided_dset->dim(),
strided_dset->stride());

search_main_core<dataset_desc_t, CagraSampleFilterT>(
res, params, dataset_desc, graph_internal, queries, neighbors, distances, sample_filter);
search_main_core<dataset_desc_t, CagraSampleFilterT>(res,
params,
dataset_desc,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<float, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
// Search using a compressed dataset
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<half, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
launch_vpq_search_main_core<T, half, ds_idx_type, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, vpq_dset, params, graph_internal, queries, neighbors, distances, sample_filter);
res,
vpq_dset,
params,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* empty_dset = dynamic_cast<const empty_dataset<ds_idx_type>*>(&index.data());
empty_dset != nullptr) {
// Forgot to add a dataset.
Expand Down
65 changes: 56 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "hashmap.hpp"
#include "utils.hpp"

#include <raft/core/operators.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/vectorized.cuh>

Expand Down Expand Up @@ -54,6 +56,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
const uint32_t num_seeds,
INDEX_T* const visited_hash_ptr,
const uint32_t hash_bitlen,
const raft::distance::DistanceType metric,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand All @@ -78,8 +81,22 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
}
}

const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, seed_index, valid_i);
DISTANCE_T norm2;
switch (metric) {
case raft::distance::L2Expanded:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::L2Expanded>(
query_buffer, seed_index, valid_i);
break;
case raft::distance::InnerProduct:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::InnerProduct>(
query_buffer, seed_index, valid_i);
break;
default: break;
}

if (valid_i && (norm2 < best_norm2_team_local)) {
best_norm2_team_local = norm2;
Expand Down Expand Up @@ -121,7 +138,8 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const INDEX_T* const internal_topk_list,
const std::uint32_t search_width)
const std::uint32_t search_width,
const raft::distance::DistanceType metric)
{
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();
Expand Down Expand Up @@ -153,8 +171,22 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
INDEX_T child_id = invalid_index;
if (valid_i) { child_id = result_child_indices_ptr[i]; }

const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, child_id, child_id != invalid_index);
DISTANCE_T norm2;
switch (metric) {
case raft::distance::L2Expanded:
norm2 =
dataset_desc
.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE, raft::distance::L2Expanded>(
query_buffer, child_id, child_id != invalid_index);
break;
case raft::distance::InnerProduct:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::InnerProduct>(
query_buffer, child_id, child_id != invalid_index);
break;
default: break;
}

// Store the distance
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand Down Expand Up @@ -220,7 +252,22 @@ struct standard_dataset_descriptor_t
}
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
template <typename T, raft::distance::DistanceType METRIC>
std::enable_if_t<METRIC == raft::distance::DistanceType::L2Expanded, T> __device__
dist_op(T a, T b) const
{
T diff = a - b;
return diff * diff;
}

template <typename T, raft::distance::DistanceType METRIC>
std::enable_if_t<METRIC == raft::distance::DistanceType::InnerProduct, T> __device__
dist_op(T a, T b) const
{
return -a * b;
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, raft::distance::DistanceType METRIC>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
Expand Down Expand Up @@ -252,9 +299,9 @@ struct standard_dataset_descriptor_t
// because:
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T diff = query_ptr[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
norm2 += diff * diff;
DISTANCE_T d = query_ptr[device::swizzling(kv)];
norm2 += dist_op<DISTANCE_T, METRIC>(
d, spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "compute_distance.hpp"

#include <raft/distance/distance_types.hpp>
#include <raft/util/integer_utils.hpp>

namespace raft::neighbors::cagra::detail {
Expand Down Expand Up @@ -112,7 +113,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, raft::distance::DistanceType METRIC>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T node_id,
const bool valid) const
Expand Down Expand Up @@ -227,4 +228,4 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
};

} // namespace raft::neighbors::cagra::detail
} // namespace raft::neighbors::cagra::detail
11 changes: 6 additions & 5 deletions cpp/include/raft/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class factory {
search_params const& params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
uint32_t topk,
const raft::distance::DistanceType metric)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, graph_degree, topk, metric);
switch (plan.dataset_block_dim) {
case 128:
switch (plan.team_size) {
Expand Down Expand Up @@ -77,17 +78,17 @@ class factory {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new single_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new multi_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new multi_kernel_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
}
}
};
Expand Down
10 changes: 8 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@
#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible
#include "utils.hpp"

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_properties.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/map.cuh>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cuda_rt_essentials.hpp>
#include <raft/util/cudart_utils.hpp> // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp
Expand Down Expand Up @@ -96,8 +99,10 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
search_params params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
: search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>(res, params, dim, graph_degree, topk),
uint32_t topk,
raft::distance::DistanceType metric)
: search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>(
res, params, dim, graph_degree, topk, metric),
intermediate_indices(0, resource::get_cuda_stream(res)),
intermediate_distances(0, resource::get_cuda_stream(res)),
topk_workspace(0, resource::get_cuda_stream(res))
Expand Down Expand Up @@ -235,6 +240,7 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

Expand Down
Loading

0 comments on commit e720de7

Please sign in to comment.