Commit
Divide the thread pool into two kinds: build and search
Signed-off-by: cqy123456 <[email protected]>
cqy123456 authored and liliu-z committed Aug 17, 2023
1 parent 721734f commit d7288bf
Showing 13 changed files with 96 additions and 65 deletions.
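For context, a minimal sketch of how a caller can now size the two global pools independently (the thread counts are arbitrary placeholders and the include path is assumed from the header location below; only the Init*/Get* names come from this commit):

#include "knowhere/comp/thread_pool.h"

int main() {
    // Placeholder sizes: a smaller pool for index building, a larger one for search.
    knowhere::ThreadPool::InitGlobalBuildThreadPool(8);
    knowhere::ThreadPool::InitGlobalSearchThreadPool(16);

    // Each getter lazily falls back to std::thread::hardware_concurrency()
    // if its Init* call was never made.
    auto build_pool = knowhere::ThreadPool::GetGlobalBuildThreadPool();
    auto search_pool = knowhere::ThreadPool::GetGlobalSearchThreadPool();
    (void)build_pool;
    (void)search_pool;
    return 0;
}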
61 changes: 44 additions & 17 deletions include/knowhere/comp/thread_pool.h
@@ -55,44 +55,70 @@ class ThreadPool {
}

/**
* @brief Set the threads number to the global thread pool of knowhere
* @brief Set the threads number to the global build thread pool of knowhere
*
* @param num_threads
*/
static void
InitGlobalThreadPool(uint32_t num_threads) {
InitThreadPool(uint32_t num_threads, uint32_t& thread_pool_size) {
if (num_threads <= 0) {
LOG_KNOWHERE_ERROR_ << "num_threads should be bigger than 0";
return;
}

if (global_thread_pool_size_ == 0) {
if (thread_pool_size == 0) {
std::lock_guard<std::mutex> lock(global_thread_pool_mutex_);
if (global_thread_pool_size_ == 0) {
global_thread_pool_size_ = num_threads;
if (thread_pool_size == 0) {
thread_pool_size = num_threads;
return;
}
}
LOG_KNOWHERE_WARNING_ << "Global ThreadPool has already been initialized with threads num: "
<< global_thread_pool_size_;
}

static void
InitGlobalBuildThreadPool(uint32_t num_threads) {
InitThreadPool(num_threads, global_build_thread_pool_size_);
LOG_KNOWHERE_WARNING_ << "Global Build ThreadPool has already been initialized with threads num: "
<< global_build_thread_pool_size_;
}

/**
* @brief Set the threads number to the global search thread pool of knowhere
*
* @param num_threads
*/
static void
InitGlobalSearchThreadPool(uint32_t num_threads) {
InitThreadPool(num_threads, global_search_thread_pool_size_);
LOG_KNOWHERE_WARNING_ << "Global Search ThreadPool has already been initialized with threads num: "
<< global_search_thread_pool_size_;
}

/**
* @brief Get the global thread pool of knowhere.
*
* @return ThreadPool&
*/

static std::shared_ptr<ThreadPool>
GetGlobalThreadPool() {
if (global_thread_pool_size_ == 0) {
std::lock_guard<std::mutex> lock(global_thread_pool_mutex_);
if (global_thread_pool_size_ == 0) {
global_thread_pool_size_ = std::thread::hardware_concurrency();
LOG_KNOWHERE_WARNING_ << "Global ThreadPool has not been initialized yet, init it with threads num: "
<< global_thread_pool_size_;
}
GetGlobalBuildThreadPool() {
if (global_build_thread_pool_size_ == 0) {
InitThreadPool(std::thread::hardware_concurrency(), global_build_thread_pool_size_);
LOG_KNOWHERE_WARNING_ << "Global Search ThreadPool has not been initialized yet, init it with threads num: "
<< global_search_thread_pool_size_;
}
static auto pool = std::make_shared<ThreadPool>(global_build_thread_pool_size_);
return pool;
}

static std::shared_ptr<ThreadPool>
GetGlobalSearchThreadPool() {
if (global_search_thread_pool_size_ == 0) {
InitThreadPool(std::thread::hardware_concurrency(), global_search_thread_pool_size_);
LOG_KNOWHERE_WARNING_ << "Global Search ThreadPool has not been initialized yet, init it with threads num: "
<< global_search_thread_pool_size_;
}
static auto pool = std::make_shared<ThreadPool>(global_thread_pool_size_);
static auto pool = std::make_shared<ThreadPool>(global_search_thread_pool_size_);
return pool;
}

@@ -110,7 +136,8 @@ class ThreadPool {

private:
folly::CPUThreadPoolExecutor pool_;
inline static uint32_t global_thread_pool_size_ = 0;
inline static uint32_t global_build_thread_pool_size_ = 0;
inline static uint32_t global_search_thread_pool_size_ = 0;
inline static std::mutex global_thread_pool_mutex_;
constexpr static size_t kTaskQueueFactor = 16;
};
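The index changes below all dispatch per-query work to the search pool with the same pattern: push a lambda per query, collect the returned folly futures, and wait. A minimal self-contained sketch of that pattern (how Knowhere itself drains the futures is not shown in this excerpt; the per-future wait() here is an assumption):

#include <cstdint>
#include <vector>
#include <folly/futures/Future.h>
#include "knowhere/comp/thread_pool.h"

void
RunOnSearchPool(int64_t nq) {
    auto pool = knowhere::ThreadPool::GetGlobalSearchThreadPool();
    std::vector<folly::Future<folly::Unit>> futs;
    futs.reserve(nq);
    for (int64_t i = 0; i < nq; ++i) {
        futs.emplace_back(pool->push([idx = i]() {
            // Per-query work goes here (e.g. one searchKnn / cached_beam_search call).
            (void)idx;
        }));
    }
    for (auto& fut : futs) {
        fut.wait();  // block until each task completes
    }
}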
6 changes: 3 additions & 3 deletions src/common/comp/brute_force.cc
@@ -50,7 +50,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto labels = new int64_t[nq * topk];
auto distances = new float[nq * topk];

auto pool = ThreadPool::GetGlobalThreadPool();
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
@@ -138,7 +138,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto labels = ids;
auto distances = dis;

auto pool = ThreadPool::GetGlobalThreadPool();
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
@@ -228,7 +228,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
bool is_ip = false;
float range_filter = cfg.range_filter.value();

auto pool = ThreadPool::GetGlobalThreadPool();
auto pool = ThreadPool::GetGlobalSearchThreadPool();

std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<float>> result_dist_array(nq);
12 changes: 6 additions & 6 deletions src/index/diskann/diskann.cc
@@ -163,7 +163,7 @@ class DiskANNIndexNode : public IndexNode {
std::unique_ptr<diskann::PQFlashIndex<T>> pq_flash_index_;
std::atomic_int64_t dim_;
std::atomic_int64_t count_;
std::shared_ptr<ThreadPool> pool_;
std::shared_ptr<ThreadPool> search_pool_;
};

} // namespace knowhere
@@ -372,7 +372,7 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
}

// set thread pool
pool_ = ThreadPool::GetGlobalThreadPool();
search_pool_ = ThreadPool::GetGlobalSearchThreadPool();

// load diskann pq code and meta info
std::shared_ptr<AlignedFileReader> reader = nullptr;
@@ -381,7 +381,7 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {

pq_flash_index_ = std::make_unique<diskann::PQFlashIndex<T>>(reader, diskann_metric);
auto disk_ann_call = [&]() {
int res = pq_flash_index_->load(pool_->size(), index_prefix_.c_str());
int res = pq_flash_index_->load(search_pool_->size(), index_prefix_.c_str());
if (res != 0) {
throw diskann::ANNException("pq_flash_index_->load returned non-zero value: " + std::to_string(res), -1);
}
@@ -472,7 +472,7 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(warmup_num);
for (_s64 i = 0; i < (int64_t)warmup_num; ++i) {
futures.emplace_back(pool_->push([&, index = i]() {
futures.emplace_back(search_pool_->push([&, index = i]() {
pq_flash_index_->cached_beam_search(warmup + (index * warmup_aligned_dim), 1, warmup_L,
warmup_result_ids_64.data() + (index * 1),
warmup_result_dists.data() + (index * 1), 4);
@@ -542,7 +542,7 @@ DiskANNIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const Bit
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(pool_->push([&, index = row]() {
futures.emplace_back(search_pool_->push([&, index = row]() {
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id + (index * k),
p_dist + (index * k), beamwidth, false, nullptr, feder_result, bitset,
filter_ratio, for_tuning);
@@ -612,7 +612,7 @@ DiskANNIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, cons
futures.reserve(nq);
bool all_searches_are_good = true;
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(pool_->push([&, index = row]() {
futures.emplace_back(search_pool_->push([&, index = row]() {
std::vector<int64_t> indices;
std::vector<float> distances;
pq_flash_index_->range_search(xq + (index * dim), radius, min_k, max_k, result_id_array[index],
8 changes: 4 additions & 4 deletions src/index/flat/flat.cc
@@ -29,7 +29,7 @@ class FlatIndexNode : public IndexNode {
FlatIndexNode(const Object&) : index_(nullptr) {
static_assert(std::is_same<T, faiss::IndexFlat>::value || std::is_same<T, faiss::IndexBinaryFlat>::value,
"not support");
pool_ = ThreadPool::GetGlobalThreadPool();
search_pool_ = ThreadPool::GetGlobalSearchThreadPool();
}

Status
@@ -88,7 +88,7 @@ class FlatIndexNode : public IndexNode {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool_->push([&, index = i] {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_ids = ids + k * index;
auto cur_dis = distances + k * index;
@@ -156,7 +156,7 @@ class FlatIndexNode : public IndexNode {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool_->push([&, index = i] {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
faiss::RangeSearchResult res(1);
if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
@@ -348,7 +348,7 @@ class FlatIndexNode : public IndexNode {

private:
std::unique_ptr<T> index_;
std::shared_ptr<ThreadPool> pool_;
std::shared_ptr<ThreadPool> search_pool_;
};

KNOWHERE_REGISTER_GLOBAL(FLAT,
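A note on the ThreadPool::ScopedOmpSetter setter(1) calls inside the FAISS-backed tasks above and below: each task pins OpenMP to a single thread for its own duration, presumably so a FAISS call running on a search-pool worker does not spawn its own OpenMP team and oversubscribe the machine. A minimal sketch of the pattern (the oversubscription rationale and the restore-on-destruction behavior of ScopedOmpSetter are assumptions, not stated in this diff):

auto pool = knowhere::ThreadPool::GetGlobalSearchThreadPool();
auto fut = pool->push([] {
    // RAII guard: run everything in this task with a single OpenMP thread.
    knowhere::ThreadPool::ScopedOmpSetter setter(1);
    // ... call into the FAISS index here ...
});
fut.wait();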
8 changes: 4 additions & 4 deletions src/index/hnsw/hnsw.cc
@@ -33,7 +33,7 @@ namespace knowhere {
class HnswIndexNode : public IndexNode {
public:
HnswIndexNode(const Object& object) : index_(nullptr) {
pool_ = ThreadPool::GetGlobalThreadPool();
search_pool_ = ThreadPool::GetGlobalSearchThreadPool();
}

Status
@@ -124,7 +124,7 @@ class HnswIndexNode : public IndexNode {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool_->push([&, idx = i]() {
futs.emplace_back(search_pool_->push([&, idx = i]() {
auto single_query = (const char*)xq + idx * index_->data_size_;
auto rst = index_->searchKnn(single_query, k, bitset, &param, feder_result);
size_t rst_size = rst.size();
@@ -198,7 +198,7 @@ class HnswIndexNode : public IndexNode {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int64_t i = 0; i < nq; ++i) {
futs.emplace_back(pool_->push([&, idx = i]() {
futs.emplace_back(search_pool_->push([&, idx = i]() {
auto single_query = (const char*)xq + idx * index_->data_size_;
auto rst = index_->searchRange(single_query, radius_for_calc, bitset, &param, feder_result);
auto elem_cnt = rst.size();
@@ -440,7 +440,7 @@ class HnswIndexNode : public IndexNode {

private:
hnswlib::HierarchicalNSW<float>* index_;
std::shared_ptr<ThreadPool> pool_;
std::shared_ptr<ThreadPool> search_pool_;
};

KNOWHERE_REGISTER_GLOBAL(HNSW, [](const Object& object) { return Index<HnswIndexNode>::Create(object); });
8 changes: 4 additions & 4 deletions src/index/ivf/ivf.cc
@@ -49,7 +49,7 @@ class IvfIndexNode : public IndexNode {
std::is_same<T, faiss::IndexIVFScalarQuantizer>::value ||
std::is_same<T, faiss::IndexBinaryIVF>::value || std::is_same<T, faiss::IndexScaNN>::value,
"not support");
pool_ = ThreadPool::GetGlobalThreadPool();
search_pool_ = ThreadPool::GetGlobalSearchThreadPool();
}
Status
Train(const DataSet& dataset, const Config& cfg) override;
@@ -196,7 +196,7 @@ class IvfIndexNode : public IndexNode {

private:
std::unique_ptr<T> index_;
std::shared_ptr<ThreadPool> pool_;
std::shared_ptr<ThreadPool> search_pool_;
};

} // namespace knowhere
@@ -407,7 +407,7 @@ IvfIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const BitsetV
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(rows);
for (int i = 0; i < rows; ++i) {
futs.emplace_back(pool_->push([&, index = i] {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
auto offset = k * index;
std::unique_ptr<float[]> copied_query = nullptr;
@@ -496,7 +496,7 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool_->push([&, index = i] {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
faiss::RangeSearchResult res(1);
std::unique_ptr<float[]> copied_query = nullptr;
3 changes: 2 additions & 1 deletion thirdparty/DiskANN/include/diskann/index.h
@@ -421,6 +421,7 @@ namespace diskann {
size_t _data_len;
size_t _neighbor_len;

std::shared_ptr<knowhere::ThreadPool> _thread_pool;
std::shared_ptr<knowhere::ThreadPool> _build_thread_pool;
std::shared_ptr<knowhere::ThreadPool> _search_thread_pool;
};
} // namespace diskann
4 changes: 2 additions & 2 deletions thirdparty/DiskANN/src/aux_utils.cpp
@@ -643,7 +643,7 @@ namespace diskann {
return;
}

auto thread_pool = knowhere::ThreadPool::GetGlobalThreadPool();
auto thread_pool = knowhere::ThreadPool::GetGlobalBuildThreadPool();

auto points_num = graph.size();
if (num_nodes_to_cache >= points_num) {
@@ -815,7 +815,7 @@ namespace diskann {
uint32_t best_bw = start_bw;
bool stop_flag = false;

auto thread_pool = knowhere::ThreadPool::GetGlobalThreadPool();
auto thread_pool = knowhere::ThreadPool::GetGlobalBuildThreadPool();

while (!stop_flag) {
std::vector<int64_t> tuning_sample_result_ids_64(tuning_sample_num, 0);