Skip to content

Commit

Permalink
enhance: optimize get norms function (#950)
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 authored Nov 15, 2024
1 parent d3605fb commit 6b7d756
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 8 deletions.
40 changes: 34 additions & 6 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
int topk = cfg.k.value();
auto labels = std::make_unique<int64_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

std::unique_ptr<float[]> norms = nullptr;
if (is_cosine) {
ThreadPool::ScopedSearchOmpSetter setter(1);
norms = std::make_unique<float[]>(nb);
faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb);
}
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
Expand All @@ -139,7 +144,8 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector);
faiss::knn_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, &buf,
id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
}
Expand Down Expand Up @@ -248,6 +254,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto labels = ids;
auto distances = dis;

std::unique_ptr<float[]> norms = nullptr;
if (is_cosine) {
ThreadPool::ScopedSearchOmpSetter setter(1);
norms = std::make_unique<float[]>(nb);
faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb);
}

auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
Expand All @@ -272,7 +285,8 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector);
faiss::knn_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, &buf,
id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
}
Expand Down Expand Up @@ -408,6 +422,13 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<float>> result_dist_array(nq);

std::unique_ptr<float[]> norms = nullptr;
if (is_cosine) {
ThreadPool::ScopedSearchOmpSetter setter(1);
norms = std::make_unique<float[]>(nb);
faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb);
}

std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
Expand Down Expand Up @@ -452,8 +473,8 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
auto cur_query = (const float*)xq + dim * index;
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::range_search_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, radius,
&res, id_selector);
faiss::range_search_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb,
radius, &res, id_selector);
} else {
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
id_selector);
Expand Down Expand Up @@ -678,6 +699,13 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
faiss::MetricType faiss_metric_type = result.value();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

std::unique_ptr<float[]> norms = nullptr;
if (is_cosine) {
ThreadPool::ScopedSearchOmpSetter setter(1);
norms = std::make_unique<float[]>(nb);
faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb);
}

auto pool = ThreadPool::GetGlobalSearchThreadPool();
auto vec = std::vector<IndexNode::IteratorPtr>(nq, nullptr);
std::vector<folly::Future<Status>> futs;
Expand All @@ -703,7 +731,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
auto cur_query = (const float*)xq + dim * index;
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::all_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, distances_ids,
faiss::all_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, distances_ids,
id_selector);
} else {
faiss::all_inner_product(cur_query, (const float*)xb, dim, 1, nb, distances_ids, id_selector);
Expand Down
37 changes: 37 additions & 0 deletions src/simd/distances_avx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,43 @@ ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) {
return res;
}

float
fvec_norm_L2sqr_avx(const float* x, size_t d) {
__m256 msum_0 = _mm256_setzero_ps();
__m256 msum_1 = _mm256_setzero_ps();
while (d >= 16) {
auto mx_0 = _mm256_loadu_ps(x);
auto mx_1 = _mm256_loadu_ps(x + 8);
msum_0 = _mm256_fmadd_ps(mx_0, mx_0, msum_0);
msum_1 = _mm256_fmadd_ps(mx_1, mx_1, msum_1);
x += 16;
d -= 16;
}
msum_0 = msum_0 + msum_1;
if (d >= 8) {
auto mx = _mm256_loadu_ps(x);
msum_0 = _mm256_fmadd_ps(mx, mx, msum_0);
x += 8;
d -= 8;
}
if (d > 0) {
__m128 rest_0 = _mm_setzero_ps();
__m128 rest_1 = _mm_setzero_ps();
if (d >= 4) {
rest_0 = _mm_loadu_ps(x);
x += 4;
d -= 4;
}
if (d >= 0) {
rest_1 = masked_read(d, x);
}
auto mx = _mm256_set_m128(rest_0, rest_1);
msum_0 = _mm256_fmadd_ps(mx, mx, msum_0);
}
auto res = _mm256_reduce_add_ps(msum_0);
return res;
}

float
fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d) {
__m256 msum_0 = _mm256_setzero_ps();
Expand Down
3 changes: 3 additions & 0 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);
int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d);

float
fvec_norm_L2sqr_avx(const float* x, size_t d);

float
fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d);

Expand Down
27 changes: 27 additions & 0 deletions src/simd/distances_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,33 @@ ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d) {
return res;
}

float
fvec_norm_L2sqr_avx512(const float* x, size_t d) {
__m512 m512_res = _mm512_setzero_ps();
__m512 m512_res_0 = _mm512_setzero_ps();
while (d >= 32) {
auto mx_0 = _mm512_loadu_ps(x);
auto mx_1 = _mm512_loadu_ps(x + 16);
m512_res = _mm512_fmadd_ps(mx_0, mx_0, m512_res);
m512_res_0 = _mm512_fmadd_ps(mx_1, mx_1, m512_res_0);
x += 32;
d -= 32;
}
m512_res = m512_res + m512_res_0;
if (d >= 16) {
auto mx = _mm512_loadu_ps(x);
m512_res = _mm512_fmadd_ps(mx, mx, m512_res);
x += 16;
d -= 16;
}
if (d > 0) {
const __mmask16 mask = (1U << d) - 1U;
auto mx = _mm512_maskz_loadu_ps(mask, x);
m512_res = _mm512_fmadd_ps(mx, mx, m512_res);
}
return _mm512_reduce_add_ps(m512_res);
}

float
fp16_vec_norm_L2sqr_avx512(const knowhere::fp16* x, size_t d) {
__m512 m512_res = _mm512_setzero_ps();
Expand Down
3 changes: 3 additions & 0 deletions src/simd/distances_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d);
int32_t
ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d);

float
fvec_norm_L2sqr_avx512(const float* x, size_t d);

float
fp16_vec_norm_L2sqr_avx512(const knowhere::fp16* x, size_t d);

Expand Down
4 changes: 2 additions & 2 deletions src/simd/hook.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ fvec_hook(std::string& simd_type) {
fvec_L1 = fvec_L1_avx512;
fvec_Linf = fvec_Linf_avx512;

fvec_norm_L2sqr = fvec_norm_L2sqr_sse;
fvec_norm_L2sqr = fvec_norm_L2sqr_avx512;
fvec_L2sqr_ny = fvec_L2sqr_ny_sse;
fvec_inner_products_ny = fvec_inner_products_ny_sse;
fvec_madd = fvec_madd_avx512;
Expand Down Expand Up @@ -213,7 +213,7 @@ fvec_hook(std::string& simd_type) {
fvec_L1 = fvec_L1_avx;
fvec_Linf = fvec_Linf_avx;

fvec_norm_L2sqr = fvec_norm_L2sqr_sse;
fvec_norm_L2sqr = fvec_norm_L2sqr_avx;
fvec_L2sqr_ny = fvec_L2sqr_ny_sse;
fvec_inner_products_ny = fvec_inner_products_ny_sse;
fvec_madd = fvec_madd_avx;
Expand Down

0 comments on commit 6b7d756

Please sign in to comment.