Skip to content

Commit

Permalink
Fix cosine bruteforce
Browse files Browse the repository at this point in the history
Signed-off-by: zh Wang <[email protected]>
  • Loading branch information
hhy3 authored and liliu-z committed Aug 11, 2023
1 parent a4636c3 commit 4f99dc0
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 35 deletions.
57 changes: 22 additions & 35 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ class BruteForceConfig : public BaseConfig {};
expected<DataSetPtr>
BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
Expand All @@ -48,7 +42,8 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
BruteForceConfig cfg;
RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH));

ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(cfg.metric_type.value()));
std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(metric_str));

int topk = cfg.k.value();
auto labels = new int64_t[nq * topk];
Expand All @@ -71,11 +66,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down Expand Up @@ -122,12 +119,6 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
Status
BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
const Json& config, const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
Expand All @@ -138,18 +129,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
BruteForceConfig cfg;
RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH));

auto metric_type = Str2FaissMetricType(cfg.metric_type.value());
if (!metric_type.has_value()) {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value();
return Status::invalid_metric_type;
}
std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(cfg.metric_type.value()));

int topk = cfg.k.value();
auto labels = ids;
auto distances = dis;

auto faiss_metric_type = metric_type.value();

auto pool = ThreadPool::GetGlobalThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
Expand All @@ -167,11 +153,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down Expand Up @@ -220,12 +208,6 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
expected<DataSetPtr>
BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
Expand All @@ -236,11 +218,13 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
BruteForceConfig cfg;
RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::RANGE_SEARCH));

std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(metric_str));

auto radius = cfg.radius.value();
bool is_ip = false;
float range_filter = cfg.range_filter.value();

ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(cfg.metric_type.value()));
auto pool = ThreadPool::GetGlobalThreadPool();

std::vector<std::vector<int64_t>> result_id_array(nq);
Expand All @@ -262,10 +246,13 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
case faiss::METRIC_INNER_PRODUCT: {
is_ip = true;
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
} else {
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
bitset);
}
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down
154 changes: 154 additions & 0 deletions thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <cmath>
#include <cstdio>
#include <cstring>
#include "simd/hook.h"

#include <omp.h>

Expand Down Expand Up @@ -284,6 +285,44 @@ void exhaustive_L2sqr_seq(
}
}

namespace {
/// Inner product of x and y divided by the L2 norm of y.
/// Only y's norm enters the denominator; x is used as given.
float fvec_cosine(const float* x, const float* y, size_t d) {
    const float dot = fvec_inner_product(x, y, d);
    const float y_length = sqrtf(fvec_norm_L2sqr(y, d));
    return dot / y_length;
}
} // namespace

/**
 * Exhaustive scan without BLAS: scores every query in x against every
 * vector in y using fvec_cosine and feeds the scores to the result handler.
 *
 * @param x       query vectors, nx rows of d floats
 * @param y       database vectors, ny rows of d floats
 * @param res     result handler; one SingleResultHandler is built from it
 *                inside the parallel region
 * @param bitset  when non-empty, vectors whose bit is set are skipped
 */
template <class ResultHandler>
void exhaustive_cosine_seq(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        ResultHandler& res,
        const BitsetView bitset) {
    using SingleResultHandler = typename ResultHandler::SingleResultHandler;
    const int thread_count = std::min(int(nx), omp_get_max_threads());

#pragma omp parallel num_threads(thread_count)
    {
        SingleResultHandler single(res);
#pragma omp for
        for (int64_t q = 0; q < nx; q++) {
            const float* query = x + q * d;
            single.begin(q);
            for (size_t j = 0; j < ny; j++) {
                // A set bit filters the vector out; an empty bitset filters
                // nothing.
                if (!bitset.empty() && bitset.test(j)) {
                    continue;
                }
                single.add_result(fvec_cosine(query, y + j * d, d), j);
            }
            single.end();
        }
    }
}

/** Find the nearest neighbors for nx queries in a set of ny vectors */
template <class ResultHandler>
void exhaustive_inner_product_blas(
Expand Down Expand Up @@ -426,6 +465,76 @@ void exhaustive_L2sqr_blas(
}
}

/**
 * Blocked BLAS exhaustive scan for cosine scoring.
 *
 * For each (query block, database block) pair the inner products are
 * computed with sgemm_, then every entry is divided by the L2 norm of the
 * corresponding database vector before being handed to the result handler.
 * Only the database norm is applied; the queries in x are used as given.
 *
 * @param x       query vectors, nx rows of d floats
 * @param y       database vectors, ny rows of d floats
 * @param res     result handler receiving the per-block scores
 * @param bitset  forwarded unchanged to res.add_results
 */
template <class ResultHandler>
void exhaustive_cosine_blas(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        ResultHandler& res,
        const BitsetView bitset = nullptr) {
    // BLAS does not like empty matrices
    if (nx == 0 || ny == 0)
        return;

    /* block sizes */
    const size_t bs_x = distance_compute_blas_query_bs;
    const size_t bs_y = distance_compute_blas_database_bs;
    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

    // The divisor is indexed by the database index j below, so the norms must
    // be those of the ny database vectors in y (not of the nx queries).
    std::unique_ptr<float[]> y_norms(new float[ny]);
    fvec_norms_L2(y_norms.get(), y, d, ny);

    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
        size_t i1 = i0 + bs_x;
        if (i1 > nx)
            i1 = nx;

        res.begin_multiple(i0, i1);

        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
            size_t j1 = j0 + bs_y;
            if (j1 > ny)
                j1 = ny;
            /* compute the actual dot products */
            {
                float one = 1, zero = 0;
                FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
                sgemm_("Transpose",
                       "Not transpose",
                       &nyi,
                       &nxi,
                       &di,
                       &one,
                       y + j0 * d,
                       &di,
                       x + i0 * d,
                       &di,
                       &zero,
                       ip_block.get(),
                       &nyi);
            }
            /* turn each inner product into ip / ||y_j|| */
#pragma omp parallel for
            for (int64_t i = i0; i < i1; i++) {
                float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);

                for (size_t j = j0; j < j1; j++) {
                    *ip_line = *ip_line / y_norms[j];
                    ip_line++;
                }
            }
            res.add_results(j0, j1, ip_block.get(), bitset);
        }
        res.end_multiple();
        InterruptCallback::check();
    }
}

template <class DistanceCorrection, class ResultHandler>
static void knn_jaccard_blas(
const float* x,
Expand Down Expand Up @@ -577,6 +686,34 @@ void knn_L2sqr(
}
}

/**
 * k-nearest-neighbour search with cosine scoring.
 *
 * A heap result handler is used when ha->k is below
 * distance_compute_min_k_reservoir, a reservoir handler otherwise. Within
 * either branch the sequential scan is used when nx is below
 * distance_compute_blas_threshold and the BLAS scan otherwise.
 *
 * @param x       query vectors, nx rows of d floats
 * @param y       database vectors, ny rows of d floats
 * @param ha      output heap array (nh, k, val, ids are read from it)
 * @param bitset  forwarded to the scan routines
 */
void knn_cosine(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        float_minheap_array_t* ha,
        const BitsetView bitset) {
    const bool sequential = nx < distance_compute_blas_threshold;

    if (ha->k < distance_compute_min_k_reservoir) {
        HeapResultHandler<CMin<float, int64_t>> res(
                ha->nh, ha->val, ha->ids, ha->k);
        if (sequential) {
            exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
        } else {
            exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
        }
    } else {
        ReservoirResultHandler<CMin<float, int64_t>> res(
                ha->nh, ha->val, ha->ids, ha->k);
        if (sequential) {
            // Must score with fvec_cosine, exactly like the heap branch;
            // the choice of result handler must not change the metric.
            exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
        } else {
            exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
        }
    }
}

struct NopDistanceCorrection {
float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const {
return dis;
Expand Down Expand Up @@ -640,6 +777,23 @@ void range_search_inner_product(
}
}

void range_search_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float radius,
RangeSearchResult* res,
const BitsetView bitset) {
RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
if (nx < distance_compute_blas_threshold) {
exhaustive_cosine_seq(x, y, d, nx, ny, resh, bitset);
} else {
exhaustive_cosine_blas(x, y, d, nx, ny, resh, bitset);
}
}

/***************************************************************************
* compute a subset of distances
***************************************************************************/
Expand Down
19 changes: 19 additions & 0 deletions thirdparty/faiss/faiss/utils/distances.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ void knn_L2sqr(
const float* y_norm2 = nullptr,
const BitsetView bitset = nullptr);

void knn_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float_minheap_array_t* ha,
const BitsetView bitset);

void knn_jaccard(
const float* x,
const float* y,
Expand Down Expand Up @@ -265,6 +274,16 @@ void range_search_inner_product(
RangeSearchResult* result,
const BitsetView bitset = nullptr);

/** Range search with cosine scoring: collect, for each query in x, the
 * vectors of y selected by the range result handler.
 *
 * @param x       query vectors, size nx * d
 * @param y       database vectors, size ny * d
 * @param radius  threshold handed to the range result handler
 *                (NOTE(review): comparison direction is decided by the
 *                handler, not visible here - confirm)
 * @param result  output range search result
 * @param bitset  optional filter forwarded to the scan routines
 */
void range_search_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float radius,
RangeSearchResult* result,
const BitsetView bitset = nullptr);

/***************************************************************************
* PQ tables computations
***************************************************************************/
Expand Down

0 comments on commit 4f99dc0

Please sign in to comment.