Skip to content

Commit

Permalink
Copy and normalize query data when cosine for thread safe
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <[email protected]>
  • Loading branch information
cydrain committed Aug 15, 2023
1 parent aa4a42f commit 2cf3483
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 64 deletions.
3 changes: 3 additions & 0 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ NormalizeVecs(float* x, size_t rows, int32_t dim);
extern void
Normalize(const DataSet& dataset);

std::unique_ptr<float[]>
CopyAndNormalizeFloatVec(const float* x, int32_t dim);

constexpr inline uint64_t seed = 0xc70f6907UL;

inline uint64_t
Expand Down
28 changes: 16 additions & 12 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset

std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(metric_str));
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

int topk = cfg.k.value();
auto labels = new int64_t[nq * topk];
Expand All @@ -65,11 +66,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
faiss::knn_cosine(cur_query_norm.get(), (const float*)xb, dim, 1, nb, &buf, bitset);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
Expand Down Expand Up @@ -131,6 +132,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_

std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(cfg.metric_type.value()));
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

int topk = cfg.k.value();
auto labels = ids;
Expand All @@ -152,11 +154,11 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
faiss::knn_cosine(cur_query_norm.get(), (const float*)xb, dim, 1, nb, &buf, bitset);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
Expand Down Expand Up @@ -220,6 +222,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da

std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(metric_str));
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

auto radius = cfg.radius.value();
bool is_ip = false;
Expand All @@ -245,10 +248,11 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
}
case faiss::METRIC_INNER_PRODUCT: {
is_ip = true;
auto cur_query = (float*)xq + dim * index;
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
auto cur_query = (const float*)xq + dim * index;
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
faiss::range_search_cosine(cur_query_norm.get(), (const float*)xb, dim, 1, nb, radius, &res,
bitset);
} else {
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
bitset);
Expand Down
9 changes: 9 additions & 0 deletions src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "knowhere/utils.h"

#include <algorithm>
#include <cmath>
#include <cstdint>

Expand Down Expand Up @@ -57,4 +58,12 @@ Normalize(const DataSet& dataset) {
}
}

std::unique_ptr<float[]>
CopyAndNormalizeFloatVec(const float* x, int32_t dim) {
auto x_norm = std::make_unique<float[]>(dim);
std::copy_n(x, dim, x_norm.get());
NormalizeVec(x_norm.get(), dim);
return x_norm;
}

} // namespace knowhere
28 changes: 16 additions & 12 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,7 @@ class FlatIndexNode : public IndexNode {

DataSetPtr results = std::make_shared<DataSet>();
const FlatConfig& f_cfg = static_cast<const FlatConfig&>(cfg);

// do normalize for COSINE metric type
if (IsMetricType(f_cfg.metric_type.value(), knowhere::metric::COSINE)) {
Normalize(dataset);
}
bool is_cosine = IsMetricType(f_cfg.metric_type.value(), knowhere::metric::COSINE);

auto k = f_cfg.k.value();
auto nq = dataset.GetRows();
Expand All @@ -97,7 +93,13 @@ class FlatIndexNode : public IndexNode {
auto cur_ids = ids + k * index;
auto cur_dis = distances + k * index;
if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
index_->search(1, (const float*)x + index * dim, k, cur_dis, cur_ids, bitset);
auto cur_query = (const float*)x + dim * index;
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
index_->search(1, cur_query_norm.get(), k, cur_dis, cur_ids, bitset);
} else {
index_->search(1, cur_query, k, cur_dis, cur_ids, bitset);
}
}
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
auto cur_i_dis = reinterpret_cast<int32_t*>(cur_dis);
Expand Down Expand Up @@ -131,11 +133,7 @@ class FlatIndexNode : public IndexNode {
}

const FlatConfig& f_cfg = static_cast<const FlatConfig&>(cfg);

// do normalize for COSINE metric type
if (IsMetricType(f_cfg.metric_type.value(), knowhere::metric::COSINE)) {
Normalize(dataset);
}
bool is_cosine = IsMetricType(f_cfg.metric_type.value(), knowhere::metric::COSINE);

auto nq = dataset.GetRows();
auto xq = dataset.GetTensor();
Expand All @@ -162,7 +160,13 @@ class FlatIndexNode : public IndexNode {
ThreadPool::ScopedOmpSetter setter(1);
faiss::RangeSearchResult res(1);
if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
index_->range_search(1, (const float*)xq + index * dim, radius, &res, bitset);
auto cur_query = (const float*)xq + dim * index;
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
index_->range_search(1, cur_query_norm.get(), radius, &res, bitset);
} else {
index_->range_search(1, cur_query, radius, &res, bitset);
}
}
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
index_->range_search(1, (const uint8_t*)xq + index * dim / 8, radius, &res, bitset);
Expand Down
4 changes: 2 additions & 2 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class HnswIndexNode : public IndexNode {
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool_->push([&, idx = i]() {
auto single_query = (const char*)xq + idx * index_->data_size_;
auto rst = index_->searchKnn((void*)single_query, k, bitset, &param, feder_result);
auto rst = index_->searchKnn(single_query, k, bitset, &param, feder_result);
size_t rst_size = rst.size();
auto p_single_dis = p_dist + idx * k;
auto p_single_id = p_id + idx * k;
Expand Down Expand Up @@ -200,7 +200,7 @@ class HnswIndexNode : public IndexNode {
for (int64_t i = 0; i < nq; ++i) {
futs.emplace_back(pool_->push([&, idx = i]() {
auto single_query = (const char*)xq + idx * index_->data_size_;
auto rst = index_->searchRange((void*)single_query, radius_for_calc, bitset, &param, feder_result);
auto rst = index_->searchRange(single_query, radius_for_calc, bitset, &param, feder_result);
auto elem_cnt = rst.size();
result_dist_array[idx].resize(elem_cnt);
result_id_array[idx].resize(elem_cnt);
Expand Down
77 changes: 53 additions & 24 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,7 @@ IvfIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const BitsetV
auto data = dataset.GetTensor();

const IvfConfig& ivf_cfg = static_cast<const IvfConfig&>(cfg);

// do normalize for COSINE metric type
if (IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE)) {
Normalize(dataset);
}
bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE);

auto k = ivf_cfg.k.value();
auto nprobe = ivf_cfg.nprobe.value();
Expand All @@ -423,17 +419,36 @@ IvfIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const BitsetV
}
}
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto cur_data = (const float*)data + index * dim;
index_->search_without_codes_thread_safe(1, cur_data, k, distances + offset, ids + offset, nprobe,
0, bitset);
auto cur_query = (const float*)data + index * dim;
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
index_->search_without_codes_thread_safe(1, cur_query_norm.get(), k, distances + offset,
ids + offset, nprobe, 0, bitset);
} else {
index_->search_without_codes_thread_safe(1, cur_query, k, distances + offset, ids + offset,
nprobe, 0, bitset);
}
} else if constexpr (std::is_same<T, faiss::IndexScaNN>::value) {
auto cur_data = (const float*)data + index * dim;
auto cur_query = (const float*)data + index * dim;
const ScannConfig& scann_cfg = static_cast<const ScannConfig&>(cfg);
index_->search_thread_safe(1, cur_data, k, distances + offset, ids + offset, nprobe,
scann_cfg.reorder_k.value(), bitset);
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
index_->search_thread_safe(1, cur_query_norm.get(), k, distances + offset, ids + offset, nprobe,
scann_cfg.reorder_k.value(), bitset);
} else {
index_->search_thread_safe(1, cur_query, k, distances + offset, ids + offset, nprobe,
scann_cfg.reorder_k.value(), bitset);
}
} else {
auto cur_data = (const float*)data + index * dim;
index_->search_thread_safe(1, cur_data, k, distances + offset, ids + offset, nprobe, 0, bitset);
auto cur_query = (const float*)data + index * dim;
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
index_->search_thread_safe(1, cur_query_norm.get(), k, distances + offset, ids + offset, nprobe,
0, bitset);
} else {
index_->search_thread_safe(1, cur_query, k, distances + offset, ids + offset, nprobe, 0,
bitset);
}
}
}));
}
Expand Down Expand Up @@ -468,11 +483,7 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
auto dim = dataset.GetDim();

const IvfConfig& ivf_cfg = static_cast<const IvfConfig&>(cfg);

// do normalize for COSINE metric type
if (IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE)) {
Normalize(dataset);
}
bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE);

float radius = ivf_cfg.radius.value();
float range_filter = ivf_cfg.range_filter.value();
Expand All @@ -498,14 +509,32 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
auto cur_data = (const uint8_t*)xq + index * dim / 8;
index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, bitset);
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto cur_data = (const float*)xq + index * dim;
index_->range_search_without_codes_thread_safe(1, cur_data, radius, &res, index_->nlist, 0, bitset);
auto cur_query = (const float*)xq + index * dim;
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
index_->range_search_without_codes_thread_safe(1, cur_query_norm.get(), radius, &res,
index_->nlist, 0, bitset);
} else {
index_->range_search_without_codes_thread_safe(1, cur_query, radius, &res, index_->nlist, 0,
bitset);
}
} else if constexpr (std::is_same<T, faiss::IndexScaNN>::value) {
auto cur_data = (const float*)xq + index * dim;
index_->range_search_thread_safe(1, cur_data, radius, &res, bitset);
auto cur_query = (const float*)xq + index * dim;
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
index_->range_search_thread_safe(1, cur_query_norm.get(), radius, &res, bitset);
} else {
index_->range_search_thread_safe(1, cur_query, radius, &res, bitset);
}
} else {
auto cur_data = (const float*)xq + index * dim;
index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, 0, bitset);
auto cur_query = (const float*)xq + index * dim;
if (is_cosine) {
auto cur_query_norm = CopyAndNormalizeFloatVec(cur_query, dim);
index_->range_search_thread_safe(1, cur_query_norm.get(), radius, &res, index_->nlist, 0,
bitset);
} else {
index_->range_search_thread_safe(1, cur_query, radius, &res, index_->nlist, 0, bitset);
}
}
auto elem_cnt = res.lims[1];
result_dist_array[index].resize(elem_cnt);
Expand Down
Loading

0 comments on commit 2cf3483

Please sign in to comment.