diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index b03b8214b..66da91c57 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -127,6 +127,7 @@ void add_node_core( raft::resource::sync_stream(handle); // Step 2: rank-based reordering + bool search_returned_not_enough_values = false; #pragma omp parallel { std::vector> detourable_node_count_list(base_degree); @@ -135,10 +136,29 @@ void add_node_core( // Count detourable edges for (std::uint32_t i = 0; i < base_degree; i++) { std::uint32_t detourable_node_count = 0; - const auto a_id = host_neighbor_indices(vec_i, i); + // TODO: the invalid indices may be produced by neighbors::cagra::search above. + // This may happen if the search function hasn't returned enough values. + // - A valid reason could be: the index size is smaller than the base degree. + // - A bad reason could be: search iterations is set to a too low value or some + // other problem with the search config. + // - This could also be a bug in the search function + // In the following, we check the indices and assign low priorities to invalid links, + // so that they are not likely to appear in the final graph. + const auto a_id = host_neighbor_indices(vec_i, i); + if (a_id >= idx.size()) { + detourable_node_count_list[i] = std::make_pair(a_id, base_degree); +#pragma omp atomic write + search_returned_not_enough_values = true; + continue; + } for (std::uint32_t j = 0; j < i; j++) { const auto b0_id = host_neighbor_indices(vec_i, j); - assert(b0_id < idx.size()); + if (b0_id >= idx.size()) { +#pragma omp atomic write + search_returned_not_enough_values = true; + detourable_node_count++; + continue; + } for (std::uint32_t k = 0; k < degree; k++) { const auto b1_id = updated_graph(b0_id, k); if (a_id == b1_id) { @@ -160,6 +180,11 @@ void add_node_core( } } } + if (search_returned_not_enough_values) { + RAFT_FAIL( + "CAGRA search returned not enough valid indices to add new nodes to the graph. " + "The resulting graph may contain invalid edges at the new nodes."); + } // Step 3: Add reverse edges const std::uint32_t rev_edge_search_range = degree / 2; @@ -248,7 +273,9 @@ void add_graph_nodes( raft::host_matrix_view updated_graph_view, const cagra::extend_params& params) { - assert(input_updated_dataset_view.extent(0) >= index.size()); + if (input_updated_dataset_view.extent(0) < index.size()) { + RAFT_FAIL("Updated dataset must be not smaller than the previous index state."); + } const std::size_t initial_dataset_size = index.size(); const std::size_t new_dataset_size = input_updated_dataset_view.extent(0); diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 5778d85a6..b4f701819 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res, using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; std::unique_ptr> plan = factory::create( - res, params, dataset_desc, queries.extent(1), graph.extent(1), topk); + res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk); plan->check(topk); diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index 7ec3d4d9e..9dcc5123b 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( const IndexT* __restrict__ seed_ptr, // [num_seeds] const uint32_t num_seeds, IndexT* __restrict__ visited_hash_ptr, - const uint32_t hash_bitlen, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hash_ptr, + const uint32_t traversed_hash_bitlen, const uint32_t block_id = 0, const uint32_t num_blocks = 1) { @@ -145,14 +147,21 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u); if (valid_i && lane_id == 0) { - if (best_index_team_local != raft::upper_bound() && - hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) { - result_distances_ptr[i] = best_norm2_team_local; - result_indices_ptr[i] = best_index_team_local; - } else { - result_distances_ptr[i] = raft::upper_bound(); - result_indices_ptr[i] = raft::upper_bound(); + if (best_index_team_local != raft::upper_bound()) { + if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. + best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } else if ((traversed_hash_ptr != nullptr) && + hashmap::search( + traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { + // Deactivate this entry as it has been already used by otehrs. + best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } } + result_distances_ptr[i] = best_norm2_team_local; + result_indices_ptr[i] = best_index_team_local; } } } @@ -168,7 +177,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const uint32_t knn_k, // hashmap IndexT* __restrict__ visited_hashmap_ptr, - const uint32_t hash_bitlen, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hashmap_ptr, + const uint32_t traversed_hash_bitlen, const IndexT* __restrict__ parent_indices, const IndexT* __restrict__ internal_topk_list, const uint32_t search_width) @@ -186,7 +197,13 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( child_id = knn_graph[(i % knn_k) + (static_cast(knn_k) * parent_id)]; } if (child_id != invalid_index) { - if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { + if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. + child_id = invalid_index; + } else if ((traversed_hashmap_ptr != nullptr) && + hashmap::search( + traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) { + // Deactivate this entry as this has been already used by others. child_id = invalid_index; } } diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index e6e7ff64f..d2ae5c55b 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -40,10 +40,11 @@ class factory { search_params const& params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) { - search_plan_impl_base plan(params, dim, graph_degree, topk); + search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk); return dispatch_kernel(res, plan, dataset_desc); } @@ -56,15 +57,15 @@ class factory { if (plan.algo == search_algo::SINGLE_CTA) { return std::make_unique< single_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else if (plan.algo == search_algo::MULTI_CTA) { return std::make_unique< multi_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else { return std::make_unique< multi_kernel_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } } }; diff --git a/cpp/src/neighbors/detail/cagra/hashmap.hpp b/cpp/src/neighbors/detail/cagra/hashmap.hpp index 2c62dda90..da736ef5e 100644 --- a/cpp/src/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/src/neighbors/detail/cagra/hashmap.hpp @@ -23,6 +23,8 @@ #include +#define HASHMAP_LINEAR_PROBING + // #pragma GCC diagnostic push // #pragma GCC diagnostic ignored // #pragma GCC diagnostic pop @@ -42,7 +44,7 @@ RAFT_DEVICE_INLINE_FUNCTION void init(IdxT* const table, } } -template +template RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) @@ -50,7 +52,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, // Open addressing is used for collision resolution const uint32_t size = get_size(bitlen); const uint32_t bit_mask = size - 1; -#if 1 +#ifdef HASHMAP_LINEAR_PROBING // Linear probing IdxT index = (key ^ (key >> bitlen)) & bit_mask; constexpr uint32_t stride = 1; @@ -59,32 +61,89 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, uint32_t index = key & bit_mask; const uint32_t stride = (key >> bitlen) * 2 + 1; #endif + constexpr IdxT hashval_empty = ~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { - const IdxT old = atomicCAS(&table[index], ~static_cast(0), key); - if (old == ~static_cast(0)) { + const IdxT old = atomicCAS(&table[index], hashval_empty, key); + if (old == hashval_empty) { return 1; } else if (old == key) { return 0; + } else if (SUPPORT_REMOVE) { + // Checks if this key has been removed before. + const uint32_t old = atomicCAS(&table[index], removed_key, key); + if (old == removed_key) { + return 1; + } else if (old == key) { + return 0; + } } index = (index + stride) & bit_mask; } return 0; } -template -RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, - const uint32_t bitlen, - const IdxT key) +template +RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key) { - IdxT ret = 0; - if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); } - for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) { - ret |= __shfl_xor_sync(0xffffffff, ret, offset); + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; +#ifdef HASHMAP_LINEAR_PROBING + // Linear probing + IdxT index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; +#else + // Double hashing + IdxT index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; +#endif + constexpr IdxT hashval_empty = ~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + for (unsigned i = 0; i < size; i++) { + const IdxT val = table[index]; + if (val == key) { + return 1; + } else if (val == hashval_empty) { + return 0; + } else if (SUPPORT_REMOVE) { + // Check if this key has been removed. + if (val == removed_key) { return 0; } + } + index = (index + stride) & bit_mask; } - return ret; + return 0; } template +RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key) +{ + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; +#ifdef HASHMAP_LINEAR_PROBING + // Linear probing + IdxT index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; +#else + // Double hashing + IdxT index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; +#endif + constexpr IdxT hashval_empty = ~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + for (unsigned i = 0; i < size; i++) { + // To remove a key, set the MSB to 1. + const uint32_t old = atomicCAS(&table[index], key, removed_key); + if (old == key) { + return 1; + } else if (old == hashval_empty) { + return 0; + } + index = (index + stride) & bit_mask; + } + return 0; +} + +template RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(unsigned team_size, IdxT* const table, const uint32_t bitlen, const IdxT key) { diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index ecfd856f1..8d425ca67 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -102,24 +102,24 @@ struct search : public search_plan_impl& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk), + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), intermediate_indices(res), intermediate_distances(res), topk_workspace(res) - { set_params(res, params); } void set_params(raft::resources const& res, const search_params& params) { - constexpr unsigned muti_cta_itopk_size = 32; - this->itopk_size = muti_cta_itopk_size; - search_width = 1; + constexpr unsigned multi_cta_itopk_size = 32; + this->itopk_size = multi_cta_itopk_size; + search_width = 1; num_cta_per_query = - max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)muti_cta_itopk_size)); + max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)multi_cta_itopk_size)); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); @@ -128,7 +128,8 @@ struct search : public search_plan_impl +template RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents( - INDEX_T* const next_parent_indices, // [search_width] - const uint32_t search_width, - INDEX_T* const itopk_indices, // [num_itopk] - const size_t num_itopk, - uint32_t* const terminate_flag) + INDEX_T* const next_parent_indices, // [num_parents] + const uint32_t num_parents, + INDEX_T* const itopk_indices, // [num_itopk] + DISTANCE_T* const itopk_distances, // [num_itopk] + const uint32_t num_itopk, // (*) num_itopk <= 32 + INDEX_T* const hash_ptr, + const uint32_t hash_bitlen) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; const unsigned warp_id = threadIdx.x / 32; if (warp_id > 0) { return; } const unsigned lane_id = threadIdx.x % 32; - for (uint32_t i = lane_id; i < search_width; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - uint32_t max_itopk = num_itopk; - if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); } - uint32_t num_new_parents = 0; - for (uint32_t j = lane_id; j < max_itopk; j += 32) { - INDEX_T index; - int new_parent = 0; - if (j < num_itopk) { - index = itopk_indices[j]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } + + // Initialize + if (lane_id < num_parents) { next_parent_indices[lane_id] = ~static_cast(0); } + INDEX_T index = ~static_cast(0); + if (lane_id < num_itopk) { index = itopk_indices[lane_id]; } + + int is_candidate = 0; + if ((index & index_msb_1_mask) == 0) { + if (hashmap::search(hash_ptr, hash_bitlen, index)) { + // Deactivate nodes that have already been used by other CTAs. + index = ~static_cast(0); + itopk_indices[lane_id] = index; + itopk_distances[lane_id] = utils::get_max_value(); + } else { + is_candidate = 1; } - const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; - if (i < search_width) { - next_parent_indices[i] = j; - itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node + } + + uint32_t num_next_parents = 0; + while (num_next_parents < num_parents) { + const uint32_t ballot_mask = __ballot_sync(0xffffffff, is_candidate); + int num_candidates = __popc(ballot_mask); + if (num_candidates == 0) { return; } + int is_found = 0; + if (is_candidate) { + const auto candidate_id = __popc(ballot_mask & ((1 << lane_id) - 1)); + if (candidate_id == 0) { + if (hashmap::insert(hash_ptr, hash_bitlen, index)) { + // Use this candidate as next parent + next_parent_indices[num_next_parents] = lane_id; + index |= index_msb_1_mask; // set most significant bit as used node + is_found = 1; + } else { + // Deactivate the node since it has been used by other CTA. + index = ~static_cast(0); + itopk_distances[lane_id] = utils::get_max_value(); + } + itopk_indices[lane_id] = index; + is_candidate = 0; } } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= search_width) { break; } + if (__ballot_sync(0xffffffff, is_found)) { num_next_parents += 1; } } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } } template @@ -121,12 +139,12 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort( } /* Warp Sort */ bitonic::warp_sort(key, val); - /* Store itopk sorted results */ + /* Store sorted results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - distances[j] = key[i]; - indices[j] = val[i]; + if (j < num_elements) { + indices[j] = val[i]; + if (j < num_itopk) { distances[j] = key[i]; } } } } @@ -148,9 +166,10 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( const uint64_t rand_xor_mask, const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, + const uint32_t visited_hash_bitlen, typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const uint32_t hash_bitlen, + traversed_hashmap_ptr, // [num_queries, 1 << traversed_hash_bitlen] + const uint32_t traversed_hash_bitlen, const uint32_t itopk_size, const uint32_t search_width, const uint32_t min_iteration, @@ -185,11 +204,11 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( extern __shared__ uint8_t smem[]; // Layout of result_buffer - // +----------------+------------------------------+---------+ - // | internal_top_k | neighbors of parent nodes | padding | + // +----------------+-------------------------------+---------+ + // | internal_top_k | neighbors of parent nodes | padding | // | | | upto 32 | - // +----------------+------------------------------+---------+ - // |<--- result_buffer_size --->| + // +----------------+-------------------------------+---------+ + // |<--- result_buffer_size --->| const auto result_buffer_size = itopk_size + (search_width * graph_degree); const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); assert(result_buffer_size_32 <= MAX_ELEMENTS); @@ -201,10 +220,10 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( reinterpret_cast(smem + dataset_desc->smem_ws_size_in_bytes()); auto* __restrict__ result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto* __restrict__ parent_indices_buffer = + auto* __restrict__ local_visited_hashmap_ptr = reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto* __restrict__ terminate_flag = - reinterpret_cast(parent_indices_buffer + search_width); + auto* __restrict__ parent_indices_buffer = + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); #if 0 /* debug */ @@ -214,9 +233,10 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( } #endif - if (threadIdx.x == 0) { terminate_flag[0] = 0; } - INDEX_T* const local_visited_hashmap_ptr = - visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); + + INDEX_T* const local_traversed_hashmap_ptr = + traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); __syncthreads(); _CLK_REC(clk_init); @@ -235,35 +255,51 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( local_seed_ptr, num_seeds, local_visited_hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, block_id, num_blocks); __syncthreads(); _CLK_REC(clk_compute_1st_distance); - uint32_t iter = 0; + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + uint32_t iter = 0; while (1) { - // topk with bitonic sort + // Topk with bitonic sort (1st warp only) _CLK_START(); topk_by_bitonic_sort(result_distances_buffer, result_indices_buffer, itopk_size + (search_width * graph_degree), itopk_size); _CLK_REC(clk_topk); + __syncthreads(); - if (iter + 1 == max_iteration) { - __syncthreads(); - break; + if (iter + 1 == max_iteration) { break; } + + // Remove entries kicked out of the itopk list from the traversed hash table. + for (unsigned i = threadIdx.x; i < search_width * graph_degree; i += blockDim.x) { + INDEX_T index = result_indices_buffer[itopk_size + i]; + if ((index & index_msb_1_mask) == 0 || (index == ~static_cast(0))) { continue; } + index &= ~index_msb_1_mask; + hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); } - // pick up next parents + // Pick up next parents (1st warp only) _CLK_START(); - pickup_next_parents( - parent_indices_buffer, search_width, result_indices_buffer, itopk_size, terminate_flag); + pickup_next_parents(parent_indices_buffer, + search_width, + result_indices_buffer, + result_distances_buffer, + itopk_size, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); _CLK_REC(clk_pickup_parents); - __syncthreads(); - if (*terminate_flag && iter >= min_iteration) { break; } + + if ((parent_indices_buffer[0] == ~static_cast(0)) && (iter >= min_iteration)) { + break; + } // compute the norms between child nodes and query node _CLK_START(); @@ -273,7 +309,9 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( knn_graph, graph_degree, local_visited_hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, parent_indices_buffer, result_indices_buffer, search_width); @@ -303,36 +341,19 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( iter++; } - // Post process for filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned i = threadIdx.x; i < itopk_size + search_width * graph_degree; i += blockDim.x) { - const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; - if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[i] = utils::get_max_value(); - result_indices_buffer[i] = invalid_index; - } - } - - __syncthreads(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (search_width * graph_degree), - itopk_size); - __syncthreads(); - } - for (uint32_t i = threadIdx.x; i < itopk_size; i += blockDim.x) { - uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit + uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + INDEX_T index = result_indices_buffer[i]; + DISTANCE_T distance = result_distances_buffer[i]; + if (index & index_msb_1_mask) { + index &= ~index_msb_1_mask; // clear most significant bit + } else { + // This entry has not been used as parent, so deactivate this. + index = ~static_cast(0); + distance = utils::get_max_value(); + } + result_indices_ptr[j] = index; + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = distance; } } if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { @@ -427,8 +448,9 @@ void select_and_run(const dataset_descriptor_host& dat uint32_t block_size, // uint32_t result_buffer_size, uint32_t smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, uint32_t num_cta_per_query, uint32_t num_seeds, SampleFilterT sample_filter, @@ -441,9 +463,13 @@ void select_and_run(const dataset_descriptor_host& dat RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Initialize hash table - const uint32_t hash_size = hashmap::get_size(hash_bitlen); - set_value_batch( - hashmap_ptr, hash_size, utils::get_max_value(), hash_size, num_queries, stream); + const uint32_t traversed_hash_size = hashmap::get_size(traversed_hash_bitlen); + set_value_batch(traversed_hashmap_ptr, + traversed_hash_size, + ~static_cast(0), + traversed_hash_size, + num_queries, + stream); dim3 block_dims(block_size, 1, 1); dim3 grid_dims(num_cta_per_query, num_queries, 1); @@ -463,8 +489,9 @@ void select_and_run(const dataset_descriptor_host& dat ps.rand_xor_mask, dev_seed_ptr, num_seeds, - hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen, ps.itopk_size, ps.search_width, ps.min_iterations, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh index 1a1dcd579..e5dc29f27 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh @@ -36,8 +36,9 @@ void select_and_run(const dataset_descriptor_host& dat uint32_t block_size, // uint32_t result_buffer_size, uint32_t smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, uint32_t num_cta_per_query, uint32_t num_seeds, SampleFilterT sample_filter, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index c6fe21642..be92be999 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -635,9 +635,10 @@ struct search : search_plan_impl { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk), + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), result_indices(res), result_distances(res), parent_node_list(res), diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 99254aa50..5b6b58a13 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -108,11 +108,17 @@ struct lightweight_uvector { }; struct search_plan_impl_base : public search_params { + int64_t dataset_size; int64_t dim; int64_t graph_degree; uint32_t topk; - search_plan_impl_base(search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) - : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk) + search_plan_impl_base( + search_params params, int64_t dim, int64_t dataset_size, int64_t graph_degree, uint32_t topk) + : search_params(params), + dim(dim), + dataset_size(dataset_size), + graph_degree(graph_degree), + topk(topk) { if (algo == search_algo::AUTO) { const size_t num_sm = raft::getMultiProcessorCount(); @@ -141,7 +147,6 @@ struct search_plan_impl : public search_plan_impl_base { size_t small_hash_bitlen; size_t small_hash_reset_interval; size_t hashmap_size; - uint32_t dataset_size; uint32_t result_buffer_size; uint32_t smem_size; @@ -157,9 +162,10 @@ struct search_plan_impl : public search_plan_impl_base { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : search_plan_impl_base(params, dim, graph_degree, topk), + : search_plan_impl_base(params, dim, dataset_size, graph_degree, topk), hashmap(res), num_executed_iterations(res), dev_seed(res), @@ -193,10 +199,16 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t _max_iterations = max_iterations; if (max_iterations == 0) { if (algo == search_algo::MULTI_CTA) { - _max_iterations = 1 + std::min(32 * 1.1, 32 + 10.0); // TODO(anaruse) + constexpr uint32_t mc_itopk_size = 32; + constexpr uint32_t mc_search_width = 1; + _max_iterations = mc_itopk_size / mc_search_width; } else { - _max_iterations = - 1 + std::min((itopk_size / search_width) * 1.1, (itopk_size / search_width) + 10.0); + _max_iterations = itopk_size / search_width; + } + int64_t num_reachable_nodes = 1; + while (num_reachable_nodes < dataset_size) { + num_reachable_nodes *= graph_degree / 2; + _max_iterations += 1; } } if (max_iterations < min_iterations) { _max_iterations = min_iterations; } @@ -219,88 +231,107 @@ struct search_plan_impl : public search_plan_impl_base { // defines hash_bitlen, small_hash_bitlen, small_hash_reset interval, hash_size inline void calc_hashmap_params(raft::resources const& res) { - // for multiple CTA search - uint32_t mc_num_cta_per_query = 0; - uint32_t mc_search_width = 0; - uint32_t mc_itopk_size = 0; - if (algo == search_algo::MULTI_CTA) { - mc_itopk_size = 32; - mc_search_width = 1; - mc_num_cta_per_query = max(search_width, raft::ceildiv(itopk_size, (size_t)32)); - RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); - RAFT_LOG_DEBUG("# mc_search_width: %u", mc_search_width); - RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); - } - // Determine hash size (bit length) hashmap_size = 0; hash_bitlen = 0; small_hash_bitlen = 0; small_hash_reset_interval = 1024 * 1024; float max_fill_rate = hashmap_max_fill_rate; - while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { - // - // The small-hash reduces hash table size by initializing the hash table - // for each iteration and re-registering only the nodes that should not be - // re-visited in that iteration. Therefore, the size of small-hash should - // be determined based on the internal topk size and the number of nodes - // visited per iteration. - // - const auto max_visited_nodes = itopk_size + (search_width * graph_degree * 1); - unsigned min_bitlen = 8; // 256 - unsigned max_bitlen = 13; // 8K - if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } - hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { - hash_bitlen += 1; - } - if (hash_bitlen > max_bitlen) { - // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. - if (hashmap_mode == hash_mode::AUTO) { - hash_bitlen = 0; - break; - } else { - RAFT_FAIL( - "small-hash cannot be used because the required hash size exceeds the limit (%u)", - hashmap::get_size(max_bitlen)); - } - } - small_hash_bitlen = hash_bitlen; + if (algo == search_algo::MULTI_CTA) { + const uint32_t mc_itopk_size = 32; + const uint32_t mc_num_cta_per_query = + max(search_width, raft::ceildiv(itopk_size, (size_t)mc_itopk_size)); + RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); + RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); // - // Sincc the hash table size is limited to a power of 2, the requirement, - // the maximum fill rate, may be satisfied even if the frequency of hash - // table reset is reduced to once every 2 or more iterations without - // changing the hash table size. In that case, reduce the reset frequency. + // [visited_hash_table] + // In the multi CTA algo, which node has been visited is managed in a hash + // table that each CTA has in the shared memory. This hash table is not + // shared among CTAs. // - small_hash_reset_interval = 1; - while (1) { - const auto max_visited_nodes = - itopk_size + (search_width * graph_degree * (small_hash_reset_interval + 1)); - if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } - small_hash_reset_interval += 1; + const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * max_iterations); + small_hash_bitlen = 11; // 2K + while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) { + small_hash_bitlen += 1; } - break; - } - if (hash_bitlen == 0) { + RAFT_EXPECTS(small_hash_bitlen <= 14, "small_hash_bitlen cannot be largen than 14 (16K)"); // - // The size of hash table is determined based on the maximum number of - // nodes that may be visited before the search is completed and the - // maximum fill rate of the hash table. + // [traversed_hash_table] + // Whether a node has ever been used as the starting point for a traversal + // in each iteration is managed in a separate hash table, which is shared + // among the CTAs. // - uint32_t max_visited_nodes = itopk_size + (search_width * graph_degree * max_iterations); - if (algo == search_algo::MULTI_CTA) { - max_visited_nodes = mc_itopk_size + (mc_search_width * graph_degree * max_iterations); - max_visited_nodes *= mc_num_cta_per_query; - } + const auto max_traversed_nodes = + mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); unsigned min_bitlen = 11; // 2K if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + while (max_traversed_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { hash_bitlen += 1; } - RAFT_EXPECTS(hash_bitlen <= 20, "hash_bitlen cannot be largen than 20 (1M)"); + RAFT_EXPECTS(hash_bitlen <= 25, "hash_bitlen cannot be largen than 25 (32M)"); + } else { + while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { + // + // The small-hash reduces hash table size by initializing the hash table + // for each iteration and re-registering only the nodes that should not be + // re-visited in that iteration. Therefore, the size of small-hash should + // be determined based on the internal topk size and the number of nodes + // visited per iteration. + // + const auto max_visited_nodes = itopk_size + (search_width * graph_degree * 1); + unsigned min_bitlen = 8; // 256 + unsigned max_bitlen = 13; // 8K + if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } + hash_bitlen = min_bitlen; + while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + hash_bitlen += 1; + } + if (hash_bitlen > max_bitlen) { + // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. + if (hashmap_mode == hash_mode::AUTO) { + hash_bitlen = 0; + break; + } else { + RAFT_FAIL( + "small-hash cannot be used because the required hash size exceeds the limit (%u)", + hashmap::get_size(max_bitlen)); + } + } + small_hash_bitlen = hash_bitlen; + // + // Sincc the hash table size is limited to a power of 2, the requirement, + // the maximum fill rate, may be satisfied even if the frequency of hash + // table reset is reduced to once every 2 or more iterations without + // changing the hash table size. In that case, reduce the reset frequency. + // + small_hash_reset_interval = 1; + while (1) { + const auto max_visited_nodes = + itopk_size + (search_width * graph_degree * (small_hash_reset_interval + 1)); + if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } + small_hash_reset_interval += 1; + } + break; + } + if (hash_bitlen == 0) { + // + // The size of hash table is determined based on the maximum number of + // nodes that may be visited before the search is completed and the + // maximum fill rate of the hash table. + // + uint32_t max_visited_nodes = itopk_size + (search_width * graph_degree * max_iterations); + unsigned min_bitlen = 11; // 2K + if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } + hash_bitlen = min_bitlen; + while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + hash_bitlen += 1; + } + RAFT_EXPECTS(hash_bitlen <= 20, + "hash_bitlen cannot be largen than 20 (1M). You can decrease itopk_size, " + "search_width or max_iterations to reduce the required hashmap size."); + } } - RAFT_LOG_DEBUG("# internal topK = %lu", itopk_size); RAFT_LOG_DEBUG("# parent size = %lu", search_width); RAFT_LOG_DEBUG("# min_iterations = %lu", min_iterations); diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index fa71dbaf9..0911d440c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -94,9 +94,10 @@ struct search : search_plan_impl { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk) + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk) { set_params(res); } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 678ed0cb4..94c97ed16 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -622,7 +622,9 @@ __device__ void search_core( local_seed_ptr, num_seeds, local_visited_hashmap_ptr, - hash_bitlen); + hash_bitlen, + (INDEX_T*)nullptr, + 0); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -749,6 +751,8 @@ __device__ void search_core( graph_degree, local_visited_hashmap_ptr, hash_bitlen, + (INDEX_T*)nullptr, + 0, parent_list_buffer, result_indices_buffer, search_width);