Skip to content

Commit

Permalink
Iterator for diskann (#874)
Browse files Browse the repository at this point in the history
* diskann iterator

diskann iterator for knowhere (opensource)

Signed-off-by: min.tian <[email protected]>

remove diskann-range_search, use index_node-range_search instead

Signed-off-by: min.tian <[email protected]>

pre-commit style

Signed-off-by: min.tian <[email protected]>

knowhere-diskann iterator fix: skip if query_norm == 0; remove range_search test - min_k < max_k

Signed-off-by: min.tian <[email protected]>

* fix

Signed-off-by: min.tian <[email protected]>

* remove config - use_reorder_data, for_tuning

Signed-off-by: min.tian <[email protected]>

---------

Signed-off-by: min.tian <[email protected]>
Co-authored-by: liziheng <[email protected]>
  • Loading branch information
alwayslove2013 and xxxlzhxxx authored Oct 10, 2024
1 parent b201140 commit 690ade0
Show file tree
Hide file tree
Showing 5 changed files with 552 additions and 136 deletions.
2 changes: 0 additions & 2 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,6 @@ class BaseConfig : public Config {
CFG_BOOL trace_visit;
CFG_BOOL enable_mmap;
CFG_BOOL enable_mmap_pop;
CFG_BOOL for_tuning;
CFG_BOOL shuffle_build;
CFG_STRING trace_id;
CFG_STRING span_id;
Expand Down Expand Up @@ -602,7 +601,6 @@ class BaseConfig : public Config {
.description("enable map_populate option for mmap")
.for_deserialize()
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(shuffle_build)
.set_default(true)
.description("shuffle ids before index building")
Expand Down
144 changes: 81 additions & 63 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ class DiskANNIndexNode : public IndexNode {
expected<DataSetPtr>
Search(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;

expected<DataSetPtr>
RangeSearch(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;

expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override;

Expand Down Expand Up @@ -152,7 +149,39 @@ class DiskANNIndexNode : public IndexNode {
return knowhere::IndexEnum::INDEX_DISKANN;
}

expected<std::vector<IndexNode::IteratorPtr>>
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;

private:
// Per-query iterator over a DiskANN flash index. It owns an IteratorWorkspace
// obtained from the index and, on each batch request, forwards whatever the
// index placed into the workspace's `backup_res`.
class iterator : public IndexIterator {
 public:
    // `transform` is forwarded to IndexIterator and also kept locally; when it
    // is true every distance handed to the batch handler is negated first.
    // `index` is not owned by the iterator.
    iterator(const bool transform, const DataType* query_data, const uint64_t lsearch, const uint64_t beam_width,
             const float filter_ratio, const knowhere::BitsetView& bitset, diskann::PQFlashIndex<DataType>* index)
        : IndexIterator(transform),
          index_(index),
          transform_(transform),
          workspace_(index->getIteratorWorkspace(query_data, lsearch, beam_width, filter_ratio, bitset)) {
    }

 protected:
    // Asks the index for the next batch, optionally flips the sign of each
    // distance, passes the batch to `batch_handler`, then empties the buffer.
    void
    next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
        auto* ws = workspace_.get();
        index_->getIteratorNextBatch(ws);

        auto& batch = ws->backup_res;
        if (transform_) {
            for (auto entry = batch.begin(); entry != batch.end(); ++entry) {
                entry->val = -entry->val;
            }
        }

        batch_handler(batch);
        batch.clear();
    }

 private:
    diskann::PQFlashIndex<DataType>* index_;  // borrowed index
    const bool transform_;                    // negate distances before handing them out
    std::unique_ptr<diskann::IteratorWorkspace<DataType>> workspace_;  // per-query search state
};

bool
LoadFile(const std::string& filename) {
if (!file_manager_->LoadFile(filename)) {
Expand Down Expand Up @@ -520,6 +549,55 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, std::shared_ptr
return Status::success;
}

/*
 * Create one ANN iterator per query row of `dataset`.
 *
 * Returns Status::empty_index when the index has not been prepared/loaded,
 * Status::invalid_metric_type when the configured metric fails CheckMetric,
 * and Status::diskann_inner_error when waiting on the per-query tasks fails.
 * On success the result holds exactly `dataset->GetRows()` iterators, in query
 * order. Each iterator is constructed and initialized on `search_pool_`.
 */
template <typename DataType>
expected<std::vector<IndexNode::IteratorPtr>>
DiskANNIndexNode<DataType>::AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg,
                                        const BitsetView& bitset) const {
    // Spell the result type once so every error branch matches the declared return type.
    using IteratorVec = std::vector<IndexNode::IteratorPtr>;

    if (!is_prepared_.load() || !pq_flash_index_) {
        LOG_KNOWHERE_ERROR_ << "Failed to load diskann.";
        return expected<IteratorVec>::Err(Status::empty_index, "DiskANN not loaded");
    }

    // Bind by reference: `cfg` outlives every use below, so copying the whole config is unnecessary.
    const auto& search_conf = static_cast<const DiskANNConfig&>(*cfg);
    const auto metric = search_conf.metric_type.value();
    if (!CheckMetric(metric)) {
        return expected<IteratorVec>::Err(Status::invalid_metric_type, "unsupported metric type");
    }

    // Fallback search-list size used when the config does not provide one.
    constexpr uint64_t k_lsearch_iterator = 32;
    const auto lsearch = static_cast<uint64_t>(search_conf.search_list_size.value_or(k_lsearch_iterator));
    const auto beamwidth = static_cast<uint64_t>(search_conf.beamwidth.value());
    const auto filter_ratio = static_cast<float>(search_conf.filter_threshold.value());

    const auto nq = dataset->GetRows();
    const auto dim = dataset->GetDim();
    // Keep the query tensor const; the iterator only reads it.
    const auto* xq = static_cast<const DataType*>(dataset->GetTensor());

    // Every metric other than L2 has its distances sign-flipped by the iterator.
    const bool transform = metric != knowhere::metric::L2;

    std::vector<folly::Future<folly::Unit>> futs;
    futs.reserve(nq);
    auto vec = IteratorVec(nq, nullptr);

    // 64-bit index so that large row counts are not narrowed.
    for (int64_t i = 0; i < nq; ++i) {
        futs.emplace_back(search_pool_->push([&, id = i]() {
            const DataType* single_query = xq + id * dim;
            auto it = std::make_shared<iterator>(transform, single_query, lsearch, beamwidth, filter_ratio, bitset,
                                                 pq_flash_index_.get());
            it->initialize();
            // Each task writes only its own slot, so the slots need no locking.
            vec[id] = it;
        }));
    }

    if (TryDiskANNCall([&]() { WaitAllSuccess(futs); }) != Status::success) {
        return expected<IteratorVec>::Err(Status::diskann_inner_error, "some ann-iterator failed");
    }

    return vec;
}

template <typename DataType>
expected<DataSetPtr>
DiskANNIndexNode<DataType>::Search(const DataSetPtr dataset, std::unique_ptr<Config> cfg,
Expand Down Expand Up @@ -586,66 +664,6 @@ DiskANNIndexNode<DataType>::Search(const DataSetPtr dataset, std::unique_ptr<Con
return res;
}

/*
 * Range search over every query row of `dataset`.
 *
 * Fails with Status::empty_index if the index is not prepared/loaded, with
 * Status::invalid_metric_type if the metric fails CheckMetric, with
 * Status::out_of_range_in_json if min_k > max_k, and with
 * Status::diskann_inner_error if waiting on the per-query tasks fails.
 */
template <typename DataType>
expected<DataSetPtr>
DiskANNIndexNode<DataType>::RangeSearch(const DataSetPtr dataset, std::unique_ptr<Config> cfg,
                                        const BitsetView& bitset) const {
    const bool index_ready = is_prepared_.load() && pq_flash_index_;
    if (!index_ready) {
        LOG_KNOWHERE_ERROR_ << "Failed to load diskann.";
        return expected<DataSetPtr>::Err(Status::empty_index, "index not loaded");
    }

    auto search_conf = static_cast<const DiskANNConfig&>(*cfg);
    if (!CheckMetric(search_conf.metric_type.value())) {
        return expected<DataSetPtr>::Err(Status::invalid_metric_type,
                                         fmt::format("unknown metric type: {}", search_conf.metric_type.value()));
    }
    if (search_conf.min_k.value() > search_conf.max_k.value()) {
        LOG_KNOWHERE_ERROR_ << "min_k should be smaller than max_k";
        return expected<DataSetPtr>::Err(Status::out_of_range_in_json, "min_k should be smaller than max_k");
    }

    // Search parameters, widened to the types the flash index expects.
    auto beam = static_cast<uint64_t>(search_conf.beamwidth.value());
    auto k_lower = static_cast<uint64_t>(search_conf.min_k.value());
    auto k_upper = static_cast<uint64_t>(search_conf.max_k.value());

    auto radius = search_conf.radius.value();
    auto range_filter = search_conf.range_filter.value();

    // Inner-product and cosine indexes share the same result-filtering path.
    const auto index_metric = pq_flash_index_->get_metric();
    bool is_ip = (index_metric == diskann::Metric::INNER_PRODUCT || index_metric == diskann::Metric::COSINE);

    auto dim = dataset->GetDim();
    auto nq = dataset->GetRows();
    auto xq = static_cast<const DataType*>(dataset->GetTensor());

    // One id list and one distance list per query row.
    std::vector<std::vector<int64_t>> ids_per_query(nq);
    std::vector<std::vector<DistType>> dists_per_query(nq);

    std::vector<folly::Future<folly::Unit>> pending;
    pending.reserve(nq);
    for (int64_t q = 0; q < nq; ++q) {
        pending.emplace_back(search_pool_->push([&, index = q]() {
            diskann::QueryStats stats;
            auto& ids = ids_per_query[index];
            auto& dists = dists_per_query[index];
            pq_flash_index_->range_search(xq + (index * dim), radius, k_lower, k_upper, ids, dists, beam, bitset,
                                          &stats);
#ifdef NOT_COMPILE_FOR_SWIG
            knowhere_diskann_range_search_iters.Observe(stats.n_iters);
#endif
            // Apply the range filter only when the caller supplied a non-default one.
            if (range_filter != defaultRangeFilter) {
                FilterRangeSearchResultForOneNq(dists, ids, is_ip, radius, range_filter);
            }
        }));
    }

    const auto wait_status = TryDiskANNCall([&]() { WaitAllSuccess(pending); });
    if (wait_status != Status::success) {
        return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
    }

    auto range_search_result = GetRangeSearchResult(dists_per_query, ids_per_query, is_ip, nq, radius, range_filter);
    return GenResultDataSet(nq, std::move(range_search_result));
}

/*
* Get raw vector data given their ids.
* It first tries to get data from cache, if failed, it will try to get data from disk.
Expand Down
11 changes: 0 additions & 11 deletions tests/ut/test_diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,6 @@ TEST_CASE("Invalid diskann params test", "[diskann]") {
REQUIRE_FALSE(res.has_value());
REQUIRE(res.error() == knowhere::Status::out_of_range_in_json);
}
// min_k > max_k
{
test_json = test_gen();
test_json["min_k"] = 10000;
test_json["max_k"] = 100;
auto res = diskann.RangeSearch(query_ds, test_json, nullptr);
REQUIRE_FALSE(res.has_value());
REQUIRE(res.error() == knowhere::Status::out_of_range_in_json);
}
#endif
}
fs::remove_all(kDir);
Expand Down Expand Up @@ -223,8 +214,6 @@ base_search() {
knowhere::Json json = base_gen();
json["index_prefix"] = metric_dir_map[metric_str];
json["beamwidth"] = 8;
json["min_k"] = 10;
json["max_k"] = 8000;
return json;
};

Expand Down
108 changes: 101 additions & 7 deletions thirdparty/DiskANN/include/diskann/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,101 @@ namespace diskann {
QueryScratch<T> scratch;
};

/** Algorithm Introduction for diskann-iterator
* First, two unbounded min-heaps are maintained: `retset` and `candidates`,
* both sorted by *pq_dist*. (This is similar to the navigation search path of
* hnswlib-hnsw with ef=1.)
*
* While candidates are being visited, another unbounded min-heap,
* `full_retset`, is maintained to collect every *full_dist* along the path.
* Each time `iterator->next` is called, `candidates` are visited repeatedly
* until the best candidate is farther than the best entry of `retset`, i.e.
* visiting continues only while
* "candidates.top.pq_dist < retset.top.pq_dist".
*
* Once that stopping condition is met, we consider `retset` to have reached
* its current optimal state, and we stop visiting new candidates. The top
* element of `full_retset` is returned as the final result. At this point, the
* top of `retset` has also completed its task and is popped directly.
*
* It is important to note that "ef=1" is clearly insufficient in terms of
* accuracy. To address this, the `lsearch` parameter is provided, which lets
* the `workspace` effectively perform `lsearch` *additional* iterations each
* time `iterator->next` is called. Specifically, this means
* "good_pq_res_count - next_count >= lsearch".
* The additional results are handed to the upper level through `backup_res`
* in the `next_batch` function, and are sorted and saved in `iterator.res_`.
*/
template<typename T>
struct IteratorWorkspace {
IteratorWorkspace(const T *query_data, const diskann::Metric metric,
const uint64_t aligned_dim, const uint64_t data_dim,
const float alpha, const uint64_t lsearch,
const uint64_t beam_width, const float filter_ratio,
const float max_base_norm,
const knowhere::BitsetView &bitset);

~IteratorWorkspace();

bool is_good_pq_enough();

bool has_candidates();

bool should_visit_next_candidate();

void insert_to_pq(unsigned id, float dist, bool valid);

void insert_to_full(unsigned id, float dist);

void pop_pq_retset();

void move_full_retset_to_backup();

void move_last_full_retset_to_backup();

uint64_t q_dim = 0;
uint64_t lsearch = 0;
uint64_t beam_width = 0;
float filter_ratio = 0;
Metric metric = Metric::L2;
float alpha = 0;
float acc_alpha = 0;
bool initialized = false;

T *aligned_query_T = nullptr;
float *aligned_query_float = nullptr;
tsl::robin_set<_u64> *visited = nullptr;
float query_norm = 0.0f;
float max_base_norm = 0.0f;
bool not_l2_but_zero = false; // (cosine or ip) and query_norm == 0.
std::vector<unsigned> frontier;
std::vector<AlignedRead> frontier_read_reqs;
const knowhere::BitsetView bitset;

std::vector<std::pair<unsigned, char *>> frontier_nhoods;
std::vector<std::pair<unsigned, std::pair<unsigned, unsigned *>>>
cached_nhoods;

struct MinHeapCompareForSimpleNeighbor {
bool operator()(const SimpleNeighbor &a, const SimpleNeighbor &b) {
return a.distance > b.distance;
}
};
std::priority_queue<SimpleNeighbor, std::vector<SimpleNeighbor>,
MinHeapCompareForSimpleNeighbor>
full_retset;
std::priority_queue<SimpleNeighbor, std::vector<SimpleNeighbor>,
MinHeapCompareForSimpleNeighbor>
retset;
std::priority_queue<SimpleNeighbor, std::vector<SimpleNeighbor>,
MinHeapCompareForSimpleNeighbor>
candidates;

size_t good_pq_res_count = 0;
size_t next_count = 0;

std::vector<knowhere::DistId> backup_res;
};

template<typename T>
class PQFlashIndex {
public:
Expand Down Expand Up @@ -105,13 +200,6 @@ namespace diskann {
knowhere::BitsetView bitset_view = nullptr,
const float filter_ratio = -1.0f);

_u32 range_search(const T *query1, const double range,
const _u64 min_l_search, const _u64 max_l_search,
std::vector<_s64> &indices, std::vector<float> &distances,
const _u64 beam_width,
knowhere::BitsetView bitset_view = nullptr,
QueryStats *stats = nullptr);

void get_vector_by_ids(const int64_t *ids, const int64_t n,
T *const output_data);

Expand All @@ -129,6 +217,12 @@ namespace diskann {

diskann::Metric get_metric() const noexcept;

void getIteratorNextBatch(IteratorWorkspace<T> *workspace);

std::unique_ptr<IteratorWorkspace<T>> getIteratorWorkspace(
const T *query_data, const uint64_t lsearch, const uint64_t beam_width,
const float filter_ratio, const knowhere::BitsetView &bitset);

_u64 cal_size();

// for async cache making task
Expand Down
Loading

0 comments on commit 690ade0

Please sign in to comment.