# Dynamic Batching (#261)
Non-blocking / stream-ordered dynamic batching as a new index type.

## API

This PR implements dynamic batching as a new index type, mirroring the API of other indices.

  * [_building is wrapping_] Building the index means creating a lightweight wrapper on top of an existing index and initializing necessary components, such as IO batch buffers and synchronization primitives.
  * [_type erasure_] The underlying/upstream index type is erased once the dynamic_batching wrapper is created, i.e. there's no way to recover the original search index type or parameters.
  * [_explicit control over batching_] To allow multiple user requests to be grouped into a dynamic batch, users must use copies of the same dynamic batching index (the user-facing index type is a thin wrapper over a shared pointer, hence copies are shallow and cheap). The search function is thread-safe.
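
The shallow-copy sharing described above can be sketched with a toy stand-in. This is not the cuVS implementation: `batcher_state` and the wrapper below are hypothetical, illustrating only the shared-pointer pattern that makes copies cheap and lets requests submitted through any copy land in one shared batch queue.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <vector>

// Hypothetical stand-in for the state shared by all copies of a dynamic
// batching index (in cuVS this would be the IO batch buffers and
// synchronization primitives; here it is just a guarded query list).
struct batcher_state {
  std::mutex mtx;
  std::vector<int> pending_queries;
};

// Thin user-facing wrapper: copying it copies only the shared_ptr, so all
// copies are shallow, cheap, and feed the same batch queue.
class dynamic_batching_index {
 public:
  dynamic_batching_index() : state_{std::make_shared<batcher_state>()} {}

  // Thread-safe: a request submitted through any copy lands in the shared queue.
  void search(int query_id)
  {
    std::lock_guard<std::mutex> lock{state_->mtx};
    state_->pending_queries.push_back(query_id);
  }

  [[nodiscard]] auto pending() const -> std::size_t
  {
    std::lock_guard<std::mutex> lock{state_->mtx};
    return state_->pending_queries.size();
  }

 private:
  std::shared_ptr<batcher_state> state_;
};
```

A copy and its original report the same pending count, which is why the PR asks users to search through copies of one index rather than building several independent ones.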

## Feature:  stream-ordered dynamic batching

Non-blocking / stream-ordered dynamic batching means the batching does not involve synchronizing with a GPU stream: control returns to the user as soon as the necessary work is submitted to the GPU. This has a few implications worth knowing:

1. The dynamic batching index has the same blocking properties as the upstream index: if the upstream index does not synchronize the stream during search, then the dynamic batching index does not either (otherwise, the dynamic batching search naturally waits while the upstream search synchronizes under the hood).
2. It is the user's responsibility to synchronize the stream before reading the results back - even if the upstream index search does not require it (the batch results are scattered back to the request threads in a post-processing kernel).
3. If the upstream index does not synchronize during search, the dynamic batching index can group queries even in a single-threaded application (_try it with the --no-lap-sync option in the ann-bench benchmarks_).
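
The grouping behavior can be modeled on the CPU with a toy batcher that shows the two dispatch triggers the parameters control: a batch is flushed when it is full (`max_batch_size`) or when the dispatch timeout expires (`dispatch_timeout_ms`). This blocking sketch is only an illustration of the grouping rule - the real cuVS index is non-blocking and stream-ordered, and `toy_batcher` is a hypothetical name.

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <vector>

class toy_batcher {
 public:
  toy_batcher(std::size_t max_batch_size, std::chrono::milliseconds timeout)
    : max_batch_size_{max_batch_size}, timeout_{timeout}
  {
  }

  // Called by request threads; returns the size of the batch this query
  // was dispatched with.
  auto submit(int query) -> std::size_t
  {
    std::unique_lock<std::mutex> lock{mtx_};
    std::uint64_t my_gen = generation_;
    batch_.push_back(query);
    if (batch_.size() >= max_batch_size_) {  // trigger 1: the batch is full
      dispatch_locked();
      return last_batch_size_;
    }
    // trigger 2: wait for more queries to arrive, but at most `timeout`
    if (!cv_.wait_for(lock, timeout_, [&] { return generation_ != my_gen; })) {
      dispatch_locked();  // timed out: flush whatever was collected
    }
    return last_batch_size_;
  }

 private:
  void dispatch_locked()
  {
    last_batch_size_ = batch_.size();
    batch_.clear();
    ++generation_;
    cv_.notify_all();
  }

  std::mutex mtx_;
  std::condition_variable cv_;
  std::vector<int> batch_;
  std::size_t max_batch_size_;
  std::chrono::milliseconds timeout_;
  std::size_t last_batch_size_ = 0;
  std::uint64_t generation_ = 0;
};
```

In the stream-ordered variant, point 3 above follows naturally: since submitting work does not block the CPU thread, even a single thread can enqueue several requests before the first batch is dispatched.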

Overall, stream-ordered dynamic batching makes it easy to wrap existing cuVS indexes, because the wrapped index exhibits the same execution behavior as the upstream index.

## Work-in-progress TODO

- [x] Add dynamic batching option to more indices in ann-bench
- [x] Add tests
- [x] **(postponed to 25.02)** Do proper benchmarking and possibly fine-tune the inter-thread communication
- [x] Review the API side (`cpp/include/cuvs/neighbors/dynamic_batching.hpp`) [ready for review CC @cjnolet]
- [x] Review the algorithm side (`cpp/src/neighbors/detail/dynamic_batching.cuh`) [ready for preliminary review: requests for algorithm docstring/clarifications are especially welcome]

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #261
achirkin authored Dec 4, 2024
1 parent a96b720 commit 9fb21ad
Showing 19 changed files with 2,539 additions and 20 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -397,6 +397,7 @@ if(BUILD_SHARED_LIBS)
src/neighbors/iface/iface_pq_uint8_t_int64_t.cu
src/neighbors/detail/cagra/cagra_build.cpp
src/neighbors/detail/cagra/topk_for_cagra/topk.cu
src/neighbors/dynamic_batching.cu
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw.cpp>
src/neighbors/ivf_flat_index.cpp
src/neighbors/ivf_flat/ivf_flat_build_extend_float_int64_t.cu
26 changes: 26 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h
@@ -56,6 +56,26 @@ extern template class cuvs::bench::cuvs_cagra<int8_t, uint32_t>;
#include "cuvs_mg_cagra_wrapper.h"
#endif

template <typename ParamT>
void parse_dynamic_batching_params(const nlohmann::json& conf, ParamT& param)
{
if (!conf.value("dynamic_batching", false)) { return; }
param.dynamic_batching = true;
if (conf.contains("dynamic_batching_max_batch_size")) {
param.dynamic_batching_max_batch_size = conf.at("dynamic_batching_max_batch_size");
}
param.dynamic_batching_conservative_dispatch =
conf.value("dynamic_batching_conservative_dispatch", false);
if (conf.contains("dynamic_batching_dispatch_timeout_ms")) {
param.dynamic_batching_dispatch_timeout_ms = conf.at("dynamic_batching_dispatch_timeout_ms");
}
if (conf.contains("dynamic_batching_n_queues")) {
param.dynamic_batching_n_queues = conf.at("dynamic_batching_n_queues");
}
param.dynamic_batching_k =
uint32_t(uint32_t(conf.at("k")) * float(conf.value("refine_ratio", 1.0f)));
}

#if defined(CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT) || defined(CUVS_ANN_BENCH_USE_CUVS_MG)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
@@ -138,6 +158,9 @@ void parse_search_param(const nlohmann::json& conf,
param.refine_ratio = conf.at("refine_ratio");
if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); }
}

// enable dynamic batching
parse_dynamic_batching_params(conf, param);
}
#endif

@@ -291,5 +314,8 @@ void parse_search_param(const nlohmann::json& conf,
}
// Same ratio as in IVF-PQ
param.refine_ratio = conf.value("refine_ratio", 1.0f);

// enable dynamic batching
parse_dynamic_batching_params(conf, param);
}
#endif
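
Based on the parser above, a benchmark search-params entry that enables dynamic batching might look like the following sketch (the surrounding ann-bench config structure is not shown here, and the values are illustrative defaults):

```json
{
  "k": 10,
  "refine_ratio": 2.0,
  "dynamic_batching": true,
  "dynamic_batching_max_batch_size": 128,
  "dynamic_batching_dispatch_timeout_ms": 0.01,
  "dynamic_batching_n_queues": 3,
  "dynamic_batching_conservative_dispatch": true
}
```

Note that `dynamic_batching_k` is not set directly: the parser derives it as `k * refine_ratio` (20 in this sketch), so the batcher collects enough candidates for the refinement step.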
97 changes: 79 additions & 18 deletions cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h
@@ -24,6 +24,7 @@
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/common.hpp>
#include <cuvs/neighbors/dynamic_batching.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <cuvs/neighbors/nn_descent.hpp>
#include <raft/core/device_mdspan.hpp>
@@ -63,6 +64,13 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
AllocatorType graph_mem = AllocatorType::kDevice;
AllocatorType dataset_mem = AllocatorType::kDevice;
[[nodiscard]] auto needs_dataset() const -> bool override { return true; }
/* Dynamic batching */
bool dynamic_batching = false;
int64_t dynamic_batching_k;
int64_t dynamic_batching_max_batch_size = 4;
double dynamic_batching_dispatch_timeout_ms = 0.01;
size_t dynamic_batching_n_queues = 8;
bool dynamic_batching_conservative_dispatch = false;
};

struct build_param {
@@ -173,6 +181,12 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
std::shared_ptr<raft::device_matrix<T, int64_t, raft::row_major>> dataset_;
std::shared_ptr<raft::device_matrix_view<const T, int64_t, raft::row_major>> input_dataset_v_;

std::shared_ptr<cuvs::neighbors::dynamic_batching::index<T, IdxT>> dynamic_batcher_;
cuvs::neighbors::dynamic_batching::search_params dynamic_batcher_sp_{};
int64_t dynamic_batching_max_batch_size_;
size_t dynamic_batching_n_queues_;
bool dynamic_batching_conservative_dispatch_;

inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type)
{
switch (mem_type) {
@@ -216,26 +230,33 @@ inline auto allocator_to_string(AllocatorType mem_type) -> std::string
template <typename T, typename IdxT>
void cuvs_cagra<T, IdxT>::set_search_param(const search_param_base& param)
{
auto sp = dynamic_cast<const search_param&>(param);
search_params_ = sp.p;
refine_ratio_ = sp.refine_ratio;
auto sp = dynamic_cast<const search_param&>(param);
bool needs_dynamic_batcher_update =
(dynamic_batching_max_batch_size_ != sp.dynamic_batching_max_batch_size) ||
(dynamic_batching_n_queues_ != sp.dynamic_batching_n_queues) ||
(dynamic_batching_conservative_dispatch_ != sp.dynamic_batching_conservative_dispatch);
dynamic_batching_max_batch_size_ = sp.dynamic_batching_max_batch_size;
dynamic_batching_n_queues_ = sp.dynamic_batching_n_queues;
dynamic_batching_conservative_dispatch_ = sp.dynamic_batching_conservative_dispatch;
search_params_ = sp.p;
refine_ratio_ = sp.refine_ratio;
if (sp.graph_mem != graph_mem_) {
// Move graph to correct memory space
graph_mem_ = sp.graph_mem;
RAFT_LOG_DEBUG("moving graph to new memory space: %s", allocator_to_string(graph_mem_).c_str());
// We create a new graph and copy to it from existing graph
auto mr = get_mr(graph_mem_);
auto new_graph = raft::make_device_mdarray<IdxT, int64_t>(
auto mr = get_mr(graph_mem_);
*graph_ = raft::make_device_mdarray<IdxT, int64_t>(
handle_, mr, raft::make_extents<int64_t>(index_->graph().extent(0), index_->graph_degree()));

raft::copy(new_graph.data_handle(),
raft::copy(graph_->data_handle(),
index_->graph().data_handle(),
index_->graph().size(),
raft::resource::get_cuda_stream(handle_));

index_->update_graph(handle_, make_const_mdspan(new_graph.view()));
// update_graph() only stores a view in the index. We need to keep the graph object alive.
*graph_ = std::move(new_graph);
// NB: update_graph() only stores a view in the index. We need to keep the graph object alive.
index_->update_graph(handle_, make_const_mdspan(graph_->view()));
needs_dynamic_batcher_update = true;
}

if (sp.dataset_mem != dataset_mem_ || need_dataset_update_) {
@@ -256,7 +277,26 @@ void cuvs_cagra<T, IdxT>::set_search_param(const search_param_base& param)
dataset_->data_handle(), dataset_->extent(0), this->dim_, dataset_->extent(1));
index_->update_dataset(handle_, dataset_view);

need_dataset_update_ = false;
need_dataset_update_ = false;
needs_dynamic_batcher_update = true;
}

// dynamic batching
if (sp.dynamic_batching) {
if (!dynamic_batcher_ || needs_dynamic_batcher_update) {
dynamic_batcher_ = std::make_shared<cuvs::neighbors::dynamic_batching::index<T, IdxT>>(
handle_,
cuvs::neighbors::dynamic_batching::index_params{{},
sp.dynamic_batching_k,
sp.dynamic_batching_max_batch_size,
sp.dynamic_batching_n_queues,
sp.dynamic_batching_conservative_dispatch},
*index_,
search_params_);
}
dynamic_batcher_sp_.dispatch_timeout_ms = sp.dynamic_batching_dispatch_timeout_ms;
} else {
if (dynamic_batcher_) { dynamic_batcher_.reset(); }
}
}

Expand Down Expand Up @@ -306,7 +346,7 @@ void cuvs_cagra<T, IdxT>::load(const std::string& file)
template <typename T, typename IdxT>
std::unique_ptr<algo<T>> cuvs_cagra<T, IdxT>::copy()
{
return std::make_unique<cuvs_cagra<T, IdxT>>(*this); // use copy constructor
return std::make_unique<cuvs_cagra<T, IdxT>>(std::cref(*this)); // use copy constructor
}

template <typename T, typename IdxT>
@@ -330,8 +370,17 @@ void cuvs_cagra<T, IdxT>::search_base(const T* queries,
raft::make_device_matrix_view<IdxT, int64_t>(neighbors_idx_t, batch_size, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);

cuvs::neighbors::cagra::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view);
if (dynamic_batcher_) {
cuvs::neighbors::dynamic_batching::search(handle_,
dynamic_batcher_sp_,
*dynamic_batcher_,
queries_view,
neighbors_view,
distances_view);
} else {
cuvs::neighbors::cagra::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view);
}

if constexpr (sizeof(IdxT) != sizeof(algo_base::index_type)) {
if (raft::get_device_for_address(neighbors) < 0 &&
@@ -367,11 +416,23 @@ void cuvs_cagra<T, IdxT>::search(
const raft::resources& res = handle_;
auto mem_type =
raft::get_device_for_address(neighbors) >= 0 ? MemoryType::kDevice : MemoryType::kHostPinned;
auto& tmp_buf = get_tmp_buffer_from_global_pool(
((disable_refinement ? 0 : (sizeof(float) + sizeof(algo_base::index_type))) +
(kNeedsIoMapping ? sizeof(IdxT) : 0)) *
batch_size * k0);
auto* candidates_ptr = reinterpret_cast<algo_base::index_type*>(tmp_buf.data(mem_type));

// If dynamic batching is used and there's no sync between benchmark laps, multiple sequential
// requests can group together. The data is copied asynchronously, and if the same intermediate
// buffer is used for multiple requests, they can override each other's data. Hence, we need to
// allocate as much space as required by the maximum number of sequential requests.
auto max_dyn_grouping = dynamic_batcher_ ? raft::div_rounding_up_safe<int64_t>(
dynamic_batching_max_batch_size_, batch_size) *
dynamic_batching_n_queues_
: 1;
auto tmp_buf_size = ((disable_refinement ? 0 : (sizeof(float) + sizeof(algo_base::index_type))) +
(kNeedsIoMapping ? sizeof(IdxT) : 0)) *
batch_size * k0;
auto& tmp_buf = get_tmp_buffer_from_global_pool(tmp_buf_size * max_dyn_grouping);
thread_local static int64_t group_id = 0;
auto* candidates_ptr = reinterpret_cast<algo_base::index_type*>(
reinterpret_cast<uint8_t*>(tmp_buf.data(mem_type)) + tmp_buf_size * group_id);
group_id = (group_id + 1) % max_dyn_grouping;
auto* candidate_dists_ptr =
reinterpret_cast<float*>(candidates_ptr + (disable_refinement ? 0 : batch_size * k0));
auto* neighbors_idx_t =
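
The scratch-buffer rotation in the hunk above can be checked with plain arithmetic. The standalone sketch below mirrors that logic with hypothetical helper names (`raft::div_rounding_up_safe` is modeled by a local `div_rounding_up`): with dynamic batching, up to `ceil(max_batch_size / batch_size) * n_queues` sequential requests may be in flight at once, so the benchmark rotates through that many buffer slots instead of reusing one.

```cpp
#include <cassert>
#include <cstdint>

// Local model of raft::div_rounding_up_safe for positive arguments.
inline auto div_rounding_up(int64_t a, int64_t b) -> int64_t { return (a + b - 1) / b; }

// Maximum number of sequential requests that may share one dynamic batch
// (the benchmark uses 1 instead when dynamic batching is disabled).
inline auto max_dyn_grouping(int64_t max_batch_size, int64_t batch_size, int64_t n_queues)
  -> int64_t
{
  return div_rounding_up(max_batch_size, batch_size) * n_queues;
}

// Byte offset of the scratch slot used by the i-th sequential request:
// slots are assigned round-robin so in-flight requests never overlap.
inline auto slot_offset(int64_t request_idx, int64_t slot_bytes, int64_t n_slots) -> int64_t
{
  return (request_idx % n_slots) * slot_bytes;
}
```

For example, the CAGRA defaults in this PR (`max_batch_size = 4`, `n_queues = 8`) with single-query requests give 32 slots, matching the worst case where every queue holds a full batch of unsynchronized laps.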
40 changes: 38 additions & 2 deletions cpp/bench/ann/src/cuvs/cuvs_ivf_pq_wrapper.h
@@ -19,7 +19,9 @@
#include "cuvs_ann_bench_utils.h"

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/dynamic_batching.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
@@ -46,6 +48,13 @@ class cuvs_ivf_pq : public algo<T>, public algo_gpu {
cuvs::neighbors::ivf_pq::search_params pq_param;
float refine_ratio = 1.0f;
[[nodiscard]] auto needs_dataset() const -> bool override { return refine_ratio > 1.0f; }
/* Dynamic batching */
bool dynamic_batching = false;
int64_t dynamic_batching_k;
int64_t dynamic_batching_max_batch_size = 128;
double dynamic_batching_dispatch_timeout_ms = 0.01;
size_t dynamic_batching_n_queues = 3;
bool dynamic_batching_conservative_dispatch = true;
};

using build_param = cuvs::neighbors::ivf_pq::index_params;
@@ -98,6 +107,9 @@ class cuvs_ivf_pq : public algo<T>, public algo_gpu {
int dimension_;
float refine_ratio_ = 1.0;
raft::device_matrix_view<const T, IdxT> dataset_;

std::shared_ptr<cuvs::neighbors::dynamic_batching::index<T, IdxT>> dynamic_batcher_;
cuvs::neighbors::dynamic_batching::search_params dynamic_batcher_sp_{};
};

template <typename T, typename IdxT>
@@ -138,6 +150,21 @@ void cuvs_ivf_pq<T, IdxT>::set_search_param(const search_param_base& param)
search_params_ = sp.pq_param;
refine_ratio_ = sp.refine_ratio;
assert(search_params_.n_probes <= index_params_.n_lists);

if (sp.dynamic_batching) {
dynamic_batcher_ = std::make_shared<cuvs::neighbors::dynamic_batching::index<T, IdxT>>(
handle_,
cuvs::neighbors::dynamic_batching::index_params{{},
sp.dynamic_batching_k,
sp.dynamic_batching_max_batch_size,
sp.dynamic_batching_n_queues,
sp.dynamic_batching_conservative_dispatch},
*index_,
search_params_);
dynamic_batcher_sp_.dispatch_timeout_ms = sp.dynamic_batching_dispatch_timeout_ms;
} else {
dynamic_batcher_.reset();
}
}

template <typename T, typename IdxT>
Expand Down Expand Up @@ -168,8 +195,17 @@ void cuvs_ivf_pq<T, IdxT>::search_base(
raft::make_device_matrix_view<IdxT, uint32_t>(neighbors_idx_t, batch_size, k);
auto distances_view = raft::make_device_matrix_view<float, uint32_t>(distances, batch_size, k);

cuvs::neighbors::ivf_pq::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view);
if (dynamic_batcher_) {
cuvs::neighbors::dynamic_batching::search(handle_,
dynamic_batcher_sp_,
*dynamic_batcher_,
queries_view,
neighbors_view,
distances_view);
} else {
cuvs::neighbors::ivf_pq::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view);
}

if constexpr (sizeof(IdxT) != sizeof(algo_base::index_type)) {
raft::linalg::unaryOp(neighbors,
4 changes: 4 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.hpp
@@ -272,6 +272,10 @@ static_assert(std::is_aggregate_v<search_params>);
*/
template <typename T, typename IdxT>
struct index : cuvs::neighbors::index {
using index_params_type = cagra::index_params;
using search_params_type = cagra::search_params;
using index_type = IdxT;
using value_type = T;
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");
