Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Make HNSW search thread safe (#406)
Browse files Browse the repository at this point in the history
Signed-off-by: liliu-z <[email protected]>

Signed-off-by: liliu-z <[email protected]>
  • Loading branch information
liliu-z authored Aug 18, 2022
1 parent a9b9607 commit f0ecc0d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
13 changes: 8 additions & 5 deletions knowhere/index/vector_index/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
}
}

index_->setEf(GetIndexParamEf(config));
size_t ef = GetIndexParamEf(config);
hnswlib::SearchParam param{ef};
bool transform = (index_->metric_type_ == 1); // InnerProduct: 1

std::chrono::high_resolution_clock::time_point query_start, query_end;
Expand All @@ -179,10 +180,10 @@ if (CheckKeyInConfig(config, meta::QUERY_THREAD_NUM))
auto single_query = (float*)p_data + i * dim;
std::priority_queue<std::pair<float, hnswlib::labeltype>> rst;
if (STATISTICS_LEVEL >= 3) {
rst = index_->searchKnn(single_query, k, bitset, query_stats[i]);
rst = index_->searchKnn(single_query, k, bitset, query_stats[i], &param);
} else {
auto dummy_stat = hnswlib::StatisticsInfo();
rst = index_->searchKnn(single_query, k, bitset, dummy_stat);
rst = index_->searchKnn(single_query, k, bitset, dummy_stat, &param);
}
size_t rst_size = rst.size();

Expand Down Expand Up @@ -246,7 +247,8 @@ IndexHNSW::QueryByRange(const DatasetPtr& dataset,

auto range_k = GetIndexParamHNSWK(config);
auto radius = GetMetaRadius(config);
index_->setEf(GetIndexParamEf(config));
size_t ef = GetIndexParamEf(config);
hnswlib::SearchParam param{ef};
bool is_IP = (index_->metric_type_ == 1); // InnerProduct: 1

if (!is_IP) {
Expand All @@ -262,7 +264,8 @@ IndexHNSW::QueryByRange(const DatasetPtr& dataset,
auto single_query = (float*)p_data + i * dim;

auto dummy_stat = hnswlib::StatisticsInfo();
auto rst = index_->searchRange(single_query, range_k, (is_IP ? 1.0f - radius : radius), bitset, dummy_stat);
auto rst =
index_->searchRange(single_query, range_k, (is_IP ? 1.0f - radius : radius), bitset, dummy_stat, &param);

for (auto& p : rst) {
result_dist_array[i].push_back(is_IP ? (1 - p.first) : p.first);
Expand Down
16 changes: 10 additions & 6 deletions thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::mutex global;
size_t ef_;

// Not thread-safe: do not call this to change ef while searches may run concurrently. Pass a SearchParam to the search call instead.
void
setEf(size_t ef) {
ef_ = ef;
Expand Down Expand Up @@ -1111,7 +1112,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
};

std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, StatisticsInfo& stats) const {
searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, StatisticsInfo& stats,
const SearchParam* param = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype>> result;
if (cur_element_count == 0)
return result;
Expand Down Expand Up @@ -1151,10 +1153,11 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
top_candidates;
size_t ef = param ? param->ef_ : this->ef_;
if (!bitset.empty()) {
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, std::max(ef_, k), bitset, stats);
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, std::max(ef, k), bitset, stats);
} else {
top_candidates = searchBaseLayerST<false, true>(currObj, query_data, std::max(ef_, k), bitset, stats);
top_candidates = searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), bitset, stats);
}

while (top_candidates.size() > k) {
Expand All @@ -1170,7 +1173,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

std::vector<std::pair<dist_t, labeltype>>
searchRange(const void* query_data, size_t range_k, float radius, const faiss::BitsetView bitset,
StatisticsInfo& stats) const {
StatisticsInfo& stats, const SearchParam* param = nullptr) const {
if (cur_element_count == 0) {
return {};
}
Expand Down Expand Up @@ -1207,10 +1210,11 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
top_candidates;
size_t ef = param ? param->ef_ : this->ef_;
if (!bitset.empty()) {
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, std::max(ef_, range_k), bitset, stats);
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, std::max(ef, range_k), bitset, stats);
} else {
top_candidates = searchBaseLayerST<false, true>(currObj, query_data, std::max(ef_, range_k), bitset, stats);
top_candidates = searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, range_k), bitset, stats);
}

while (top_candidates.size() > range_k) {
Expand Down
13 changes: 9 additions & 4 deletions thirdparty/hnswlib/hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,21 @@ class StatisticsInfo {
std::vector<uint32_t> accessed_points_;
};

// Per-call search parameters. When a non-null SearchParam is passed to
// HierarchicalNSW::searchKnn / searchRange, its ef_ is used for that call
// in place of the index's own ef_ member.
struct SearchParam {
// ef value to use for this search call.
size_t ef_;
};

template<typename dist_t>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label)=0;

virtual std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *, size_t, const faiss::BitsetView, hnswlib::StatisticsInfo&) const = 0;
virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, const faiss::BitsetView, hnswlib::StatisticsInfo&, const SearchParam*) const = 0;

virtual std::vector<std::pair<dist_t, labeltype>>
searchRange(const void*, size_t, float, const faiss::BitsetView, hnswlib::StatisticsInfo&) const = 0;
searchRange(const void*, size_t, float, const faiss::BitsetView, hnswlib::StatisticsInfo&,
const SearchParam*) const = 0;

// Return the k nearest neighbors, ordered closest first
virtual std::vector<std::pair<dist_t, labeltype>>
Expand All @@ -202,7 +207,7 @@ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t
std::vector<std::pair<dist_t, labeltype>> result;

// here searchKnn returns the result in the order of further first
auto ret = searchKnn(query_data, k, bitset, stats);
auto ret = searchKnn(query_data, k, bitset, stats, nullptr);
{
size_t sz = ret.size();
result.resize(sz);
Expand Down

0 comments on commit f0ecc0d

Please sign in to comment.