Skip to content

Commit

Permalink
CAGRA: reduce argument count in select_and_run() kernel wrappers (rap…
Browse files Browse the repository at this point in the history
…idsai#227)

A small change that reduces the number of arguments in one of the wrapper layers in the detail namespace of CAGRA. The goal is twofold:
  1) Simplify the overly long signature of `selet_and_run` (which has many instances) 
  2) Give access to all search parameters for future upgrades of the search kernel

This is to simplify the integration (and review) of the persistent kernel (rapidsai#215).
No performance or functional changes expected.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: rapidsai#227
  • Loading branch information
achirkin authored and divyegala committed Jul 31, 2024
1 parent 8860a09 commit 7e7f8cb
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 97 deletions.
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,15 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
dev_seed_ptr,
num_executed_iterations,
*this,
topk,
thread_block_size,
result_buffer_size,
smem_size,
hash_bitlen,
hashmap.data(),
num_cta_per_query,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
search_width,
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,15 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search {
const uint32_t num_queries, \
const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
40 changes: 13 additions & 27 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,29 @@ namespace multi_cta_search {
#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
class DATASET_DESCRIPTOR_T,
class SAMPLE_FILTER_T>
unsigned DATASET_BLOCK_DIM,
typename DATASET_DESCRIPTOR_T,
typename SAMPLE_FILTER_T>
void select_and_run(
DATASET_DESCRIPTOR_T dataset_desc,
raft::device_matrix_view<const typename DATASET_DESCRIPTOR_T::INDEX_T, int64_t, raft::row_major>
graph,
typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr,
typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr,
const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr,
typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk]
typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk]
const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim]
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr,
uint32_t* const num_executed_iterations,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
uint32_t block_size,
// multi_cta_search (params struct)
uint32_t block_size, //
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream) RAFT_EXPLICIT;
Expand All @@ -75,20 +71,15 @@ void select_and_run(
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down Expand Up @@ -160,20 +151,15 @@ instantiate_kernel_selection(
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
19 changes: 7 additions & 12 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ void select_and_run(
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
// multi_cta_search (params struct)
uint32_t block_size, //
Expand All @@ -466,13 +467,7 @@ void select_and_run(
int64_t hash_bitlen,
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream)
Expand Down Expand Up @@ -507,16 +502,16 @@ void select_and_run(
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
hash_bitlen,
itopk_size,
search_width,
min_iterations,
max_iterations,
ps.itopk_size,
ps.search_width,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
sample_filter,
metric);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
dev_seed_ptr,
num_executed_iterations,
*this,
topk,
num_itopk_candidates,
static_cast<uint32_t>(thread_block_size),
Expand All @@ -241,13 +242,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
hashmap.data(),
small_hash_bitlen,
small_hash_reset_interval,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
search_width,
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
const uint32_t num_queries, \
const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t num_itopk_candidates, \
uint32_t block_size, \
Expand All @@ -40,13 +41,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \
size_t small_hash_bitlen, \
size_t small_hash_reset_interval, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
27 changes: 6 additions & 21 deletions cpp/src/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ namespace single_cta_search {
#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
unsigned DATASET_BLOCK_DIM,
typename DATASET_DESCRIPTOR_T,
typename SAMPLE_FILTER_T>
void select_and_run( // raft::resources const& res,
void select_and_run(
DATASET_DESCRIPTOR_T dataset_desc,
raft::device_matrix_view<const typename DATASET_DESCRIPTOR_T::INDEX_T, int64_t, raft::row_major>
graph,
Expand All @@ -39,21 +39,16 @@ void select_and_run( // raft::resources const& res,
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
uint32_t num_itopk_candidates,
uint32_t block_size,
uint32_t block_size, //
uint32_t smem_size,
int64_t hash_bitlen,
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
size_t small_hash_bitlen,
size_t small_hash_reset_interval,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream) RAFT_EXPLICIT;
Expand All @@ -76,6 +71,7 @@ void select_and_run( // raft::resources const& res,
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t num_itopk_candidates, \
uint32_t block_size, \
Expand All @@ -84,13 +80,7 @@ void select_and_run( // raft::resources const& res,
INDEX_T* hashmap_ptr, \
size_t small_hash_bitlen, \
size_t small_hash_reset_interval, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down Expand Up @@ -162,6 +152,7 @@ instantiate_single_cta_select_and_run(
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t num_itopk_candidates, \
uint32_t block_size, \
Expand All @@ -170,13 +161,7 @@ instantiate_single_cta_select_and_run(
INDEX_T* hashmap_ptr, \
size_t small_hash_bitlen, \
size_t small_hash_reset_interval, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
Loading

0 comments on commit 7e7f8cb

Please sign in to comment.