Skip to content

Commit

Permalink
Return the error message while failed to search with DiskANN
Browse files Browse the repository at this point in the history
Signed-off-by: yah01 <[email protected]>
  • Loading branch information
yah01 committed Aug 21, 2023
1 parent 57a7fcf commit 8b497b0
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 139 deletions.
70 changes: 22 additions & 48 deletions include/knowhere/expected.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,6 @@ class expected {
expected(Args&&... args) : val(std::make_optional<T>(std::forward<Args>(args)...)) {
}

expected(const Status& err) : err(err) {
assert(err != Status::success);
}

expected(Status&& err) : err(err) {
assert(err != Status::success);
}

expected(const expected<T>&) = default;

expected(expected<T>&&) noexcept = default;
Expand All @@ -73,8 +65,7 @@ class expected {

Status
error() const {
assert(val.has_value() == false);
return err.value();
return err;
}

const T&
Expand All @@ -100,9 +91,27 @@ class expected {
return *this;
}

static expected<T>
OK() {
return expected(Status::success);
}

static expected<T>
Err(const Status err, std::string msg) {
return expected(err, std::move(msg));
}

private:
// keep these private to avoid creating directly
expected(const Status err) : err(err) {
}

expected(const Status err, std::string msg) : err(err), msg(std::move(msg)) {
assert(err != Status::success);
}

std::optional<T> val = std::nullopt;
std::optional<Status> err = std::nullopt;
Status err;
std::string msg;
};

Expand All @@ -117,49 +126,14 @@ class expected {
} while (0)

template <typename T>
Status
expected<T>
DoAssignOrReturn(T& lhs, const expected<T>& exp) {
if (exp.has_value()) {
lhs = exp.value();
return Status::success;
}
return exp.error();
return exp;
}

#define STATUS_INTERNAL_CONCAT_NAME_INNER(x, y) x##y
#define STATUS_INTERNAL_CONCAT_NAME(x, y) STATUS_INTERNAL_CONCAT_NAME_INNER(x, y)

#define STATUS_INTERNAL_DEPAREN(X) STATUS_INTERNAL_ESC(STATUS_INTERNAL_ISH X)
#define STATUS_INTERNAL_ISH(...) STATUS_INTERNAL_ISH __VA_ARGS__
#define STATUS_INTERNAL_ESC(...) STATUS_INTERNAL_ESC_(__VA_ARGS__)
#define STATUS_INTERNAL_ESC_(...) STATUS_INTERNAL_VAN_STATUS_INTERNAL_##__VA_ARGS__
#define STATUS_INTERNAL_VAN_STATUS_INTERNAL_STATUS_INTERNAL_ISH

#define STATUS_INTERNAL_ASSIGN_OR_RETURN_IMPL(status, lhs, rexpr) \
Status status = knowhere::DoAssignOrReturn(lhs, rexpr); \
if (status != Status::success) { \
return status; \
}

// Evaluates an expression that returns an `expected`. If the expected has a value, assigns
// the value to var. Otherwise returns the error from the current function.
//
// Example: ASSIGN_OR_RETURN(int, i, MaybeInt());
//
// If the type parameter has comma not wrapped by paired parenthesis/double quotes, wrap
// the comma in parenthesis properly.
//
// Examples:
// ASSIGN_OR_RETURN(std::pair<int, int>, pair, MaybePair()); // Not OK
// ASSIGN_OR_RETURN((std::pair<int, int>), pair, MaybePair()); // OK
// ASSIGN_OR_RETURN(std::function<void(int, int)>), fn, MaybeFunction()); // OK
//
// Note that this macro expands into multiple statements and thus cannot be used in a single statement
// such as the body of an if statement without {}.
#define ASSIGN_OR_RETURN(type, var, rexpr) \
STATUS_INTERNAL_DEPAREN(type) var; \
STATUS_INTERNAL_ASSIGN_OR_RETURN_IMPL(STATUS_INTERNAL_CONCAT_NAME(_excepted_, __COUNTER__), var, rexpr)

} // namespace knowhere

#endif /* EXPECTED_H */
39 changes: 32 additions & 7 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "common/metric.h"
#include "common/range_util.h"
#include "faiss/MetricType.h"
#include "faiss/utils/binary_distances.h"
#include "faiss/utils/distances.h"
#include "knowhere/comp/thread_pool.h"
Expand All @@ -40,10 +41,18 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto nq = query_dataset->GetRows();

BruteForceConfig cfg;
RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH));
std::string msg;
auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg);
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, msg);
}

std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(metric_str));
auto result = Str2FaissMetricType(metric_str);
if (result.error() != Status::success) {
return expected<DataSetPtr>::Err(result.error(), result.what());
}
faiss::MetricType faiss_metric_type = result.value();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

int topk = cfg.k.value();
Expand Down Expand Up @@ -112,7 +121,9 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
RETURN_IF_ERROR(ret);
if (ret == Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
}
return GenResultDataSet(nq, cfg.k.value(), labels, distances);
}
Expand All @@ -131,7 +142,11 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH));

std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(cfg.metric_type.value()));
auto result = Str2FaissMetricType(cfg.metric_type.value());
if (result.error() != Status::success) {
return result.error();
}
faiss::MetricType faiss_metric_type = result.value();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

int topk = cfg.k.value();
Expand Down Expand Up @@ -218,10 +233,18 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
auto nq = query_dataset->GetRows();

BruteForceConfig cfg;
RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::RANGE_SEARCH));
std::string msg;
auto status = Config::Load(cfg, config, knowhere::RANGE_SEARCH, &msg);
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, std::move(msg));
}

std::string metric_str = cfg.metric_type.value();
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(metric_str));
auto result = Str2FaissMetricType(metric_str);
if (result.error() != Status::success) {
return expected<DataSetPtr>::Err(result.error(), result.what());
}
faiss::MetricType faiss_metric_type = result.value();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

auto radius = cfg.radius.value();
Expand Down Expand Up @@ -295,7 +318,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
RETURN_IF_ERROR(ret);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
}

int64_t* ids = nullptr;
Expand Down
27 changes: 18 additions & 9 deletions src/common/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

#include "knowhere/index.h"

#include "knowhere/dataset.h"
#include "knowhere/expected.h"
#include "knowhere/log.h"

#ifdef NOT_COMPILE_FOR_SWIG
Expand Down Expand Up @@ -65,15 +67,11 @@ Index<T>::Search(const DataSet& dataset, const Json& json, const BitsetView& bit
std::string msg;
const Status load_status = LoadConfig(cfg.get(), json, knowhere::SEARCH, "Search", &msg);
if (load_status != Status::success) {
expected<DataSetPtr> ret(load_status);
ret << msg;
return ret;
return expected<DataSetPtr>::Err(load_status, msg);
}
const Status search_status = cfg->CheckAndAdjustForSearch(&msg);
if (search_status != Status::success) {
expected<DataSetPtr> ret(search_status);
ret << msg;
return ret;
return expected<DataSetPtr>::Err(search_status, msg);
}

#ifdef NOT_COMPILE_FOR_SWIG
Expand All @@ -87,8 +85,15 @@ template <typename T>
inline expected<DataSetPtr>
Index<T>::RangeSearch(const DataSet& dataset, const Json& json, const BitsetView& bitset) const {
auto cfg = this->node->CreateConfig();
RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::RANGE_SEARCH, "RangeSearch"));
RETURN_IF_ERROR(cfg->CheckAndAdjustForRangeSearch());
std::string msg;
auto status = LoadConfig(cfg.get(), json, knowhere::RANGE_SEARCH, "RangeSearch", &msg);
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, std::move(msg));
}
status = cfg->CheckAndAdjustForRangeSearch();
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, "invalid params for range search");
}

#ifdef NOT_COMPILE_FOR_SWIG
knowhere_range_search_count.Increment();
Expand All @@ -112,7 +117,11 @@ template <typename T>
inline expected<DataSetPtr>
Index<T>::GetIndexMeta(const Json& json) const {
auto cfg = this->node->CreateConfig();
RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::FEDER, "GetIndexMeta"));
std::string msg;
auto status = LoadConfig(cfg.get(), json, knowhere::FEDER, "GetIndexMeta", &msg);
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, msg);
}
return this->node->GetIndexMeta(*cfg);
}

Expand Down
7 changes: 5 additions & 2 deletions src/common/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <unordered_map>

#include "faiss/MetricType.h"
#include "fmt/format.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/expected.h"

Expand All @@ -36,8 +37,10 @@ Str2FaissMetricType(std::string metric) {

std::transform(metric.begin(), metric.end(), metric.begin(), toupper);
auto it = metric_map.find(metric);
if (it == metric_map.end())
return Status::invalid_metric_type;
if (it == metric_map.end()) {
return expected<faiss::MetricType>::Err(Status::invalid_metric_type,
fmt::format("unsupported metric type {}", metric));
}
return it->second;
}

Expand Down
4 changes: 3 additions & 1 deletion src/common/raft_metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <string>
#include <unordered_map>

#include "fmt/format.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/expected.h"
#include "raft/distance/distance_types.hpp"
Expand All @@ -39,7 +40,8 @@ Str2RaftMetricType(std::string metric) {
std::transform(metric.begin(), metric.end(), metric.begin(), toupper);
auto it = metric_map.find(metric);
if (it == metric_map.end())
return Status::invalid_metric_type;
return expected<raft::distance::DistanceType>::Err(Status::invalid_metric_type,
fmt::format("unsupported metric type: {}", metric));
return it->second;
}

Expand Down
Loading

0 comments on commit 8b497b0

Please sign in to comment.