Skip to content

Commit

Permalink
fix: fail to load ivf_flat(metric = cosine) with DeserializeFromFile (#…
Browse files Browse the repository at this point in the history
…841)

Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 authored Sep 20, 2024
1 parent 3e4454d commit 81d3b38
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 39 deletions.
105 changes: 69 additions & 36 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
namespace {
// Lower bound that the measured KNN recall must exceed in the search checks below.
constexpr float kKnnRecallThreshold{0.6f};
// Recall threshold for brute-force checks (NOTE(review): its use is not visible in this hunk — confirm).
constexpr float kBruteForceRecallThreshold{0.95f};
// Scratch file that receives the serialized index so it can be loaded back via DeserializeFromFile.
constexpr const char* kMmapIndexPath{"/tmp/knowhere_dense_mmap_index_test"};
}  // namespace

TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
Expand Down Expand Up @@ -166,49 +167,81 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen),
}));
auto idx_expected = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
if (name == knowhere::IndexEnum::INDEX_FAISS_SCANN) {
// need to check cpu model for scann
if (!faiss::support_pq_fast_scan) {
REQUIRE(idx_expected.error() == knowhere::Status::invalid_index_error);
return;
knowhere::BinarySet bs;
// build process
{
auto idx_expected = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
if (name == knowhere::IndexEnum::INDEX_FAISS_SCANN) {
// need to check cpu model for scann
if (!faiss::support_pq_fast_scan) {
REQUIRE(idx_expected.error() == knowhere::Status::invalid_index_error);
return;
}
}
auto idx = idx_expected.value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);
REQUIRE(idx.Count() == nb);

REQUIRE(idx.Serialize(bs) == knowhere::Status::success);
}
auto idx = idx_expected.value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);
REQUIRE(idx.Count() == nb);
// search process
auto load_with_mmap = GENERATE(as<bool>{}, true, false);
{
auto idx_expected = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto idx = idx_expected.value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
// TODO: qianya(DeserializeFromFile need raw data path. Next pr will remove raw data in ivf sq cc index, and
// use a knowhere struct to maintain raw data)
if (load_with_mmap && knowhere::KnowhereCheck::SupportMmapIndexTypeCheck(name) &&
name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC) {
auto binary = bs.GetByName(idx.Type());
auto data = binary->data.get();
auto size = binary->size;
std::remove(kMmapIndexPath);
std::ofstream out(kMmapIndexPath, std::ios::binary);
out.write((const char*)data, size);
out.close();
json["enable_mmap"] = true;
REQUIRE(idx.DeserializeFromFile(kMmapIndexPath, json) == knowhere::Status::success);
} else {
REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success);
}

knowhere::BinarySet bs;
REQUIRE(idx.Serialize(bs) == knowhere::Status::success);
REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success);
// TODO: qianya (IVFSQ_CC deserialize casted from the IVFSQ directly, which will cause the hasRawData reference
// to an uncertain address)
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC) {
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
}
// TODO: qianya (IVFSQ_CC deserialize casted from the IVFSQ directly, which will cause the hasRawData
// reference to an uncertain address)
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC) {
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
}

auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
float recall = GetKNNRecall(*gt.value(), *results.value());
bool scann_without_raw_data =
(name == knowhere::IndexEnum::INDEX_FAISS_SCANN && scann_gen2().dump() == cfg_json);
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && !scann_without_raw_data) {
REQUIRE(recall > kKnnRecallThreshold);
}
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
float recall = GetKNNRecall(*gt.value(), *results.value());
bool scann_without_raw_data =
(name == knowhere::IndexEnum::INDEX_FAISS_SCANN && scann_gen2().dump() == cfg_json);
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && !scann_without_raw_data) {
REQUIRE(recall > kKnnRecallThreshold);
}

if (metric == knowhere::metric::COSINE) {
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ &&
name != knowhere::IndexEnum::INDEX_HNSW_SQ8 && name != knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE &&
name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC && !scann_without_raw_data) {
REQUIRE(CheckDistanceInScope(*results.value(), topk, -1.00001, 1.00001));
if (metric == knowhere::metric::COSINE) {
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ &&
name != knowhere::IndexEnum::INDEX_HNSW_SQ8 && name != knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE &&
name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC && !scann_without_raw_data) {
REQUIRE(CheckDistanceInScope(*results.value(), topk, -1.00001, 1.00001));
}
}
}
if (load_with_mmap && knowhere::KnowhereCheck::SupportMmapIndexTypeCheck(name) &&
name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC) {
std::remove(kMmapIndexPath);
}
}

SECTION("Test Range Search") {
Expand Down
29 changes: 26 additions & 3 deletions thirdparty/faiss/faiss/invlists/OnDiskInvertedLists.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,25 @@ void OnDiskInvertedLists::set_all_lists_sizes(const size_t* sizes) {
for (size_t i = 0; i < nlist; i++) {
lists[i].offset = ofs;
lists[i].capacity = lists[i].size = sizes[i];
ofs += sizes[i] * (sizeof(idx_t) + code_size);
if (this->with_norm) {
ofs += sizes[i] * (sizeof(idx_t) + code_size + sizeof(float));
} else {
ofs += sizes[i] * (sizeof(idx_t) + code_size);
}
}
}

/// Locate the float norms stored for inverted list `list_no`.
///
/// The returned address is computed from the base pointer `ptr`, the list's
/// byte offset, and `capacity` entries of (code_size + sizeof(idx_t)) bytes,
/// i.e. the norms are expected to follow the list's codes and ids.
///
/// @param list_no  index into `lists`; not range-checked here.
/// @param offset   unused.
/// @return pointer to the norms, or nullptr when `with_norm` is false or the
///         list's offset is INVALID_OFFSET.
const float* OnDiskInvertedLists::get_code_norms(
        size_t list_no,
        size_t /*offset*/) const {
    // Without stored norms there is nothing to return.
    if (!with_norm) {
        return nullptr;
    }

    const auto& entry = lists[list_no];
    if (entry.offset == INVALID_OFFSET) {
        return nullptr;
    }

    // Skip over the codes and ids of this list to reach its norms.
    const size_t codes_and_ids_bytes =
            (code_size + sizeof(idx_t)) * entry.capacity;
    return reinterpret_cast<const float*>(
            ptr + entry.offset + codes_and_ids_bytes);
}

Expand Down Expand Up @@ -761,14 +779,15 @@ InvertedLists* OnDiskInvertedListsIOHook::read(IOReader* f, int io_flags)
/** read from a ArrayInvertedLists into this invertedlist type */
InvertedLists* OnDiskInvertedListsIOHook::read_ArrayInvertedLists(
IOReader* f,
int /* io_flags */,
int io_flags,
size_t nlist,
size_t code_size,
const std::vector<size_t>& sizes) const {
auto ails = new OnDiskInvertedLists();
ails->nlist = nlist;
ails->code_size = code_size;
ails->read_only = true;
ails->with_norm = io_flags & IO_FLAG_WITH_NORM;
ails->lists.resize(nlist);

FileIOReader* reader = dynamic_cast<FileIOReader*>(f);
Expand Down Expand Up @@ -798,7 +817,11 @@ InvertedLists* OnDiskInvertedListsIOHook::read_ArrayInvertedLists(
OnDiskInvertedLists::List& l = ails->lists[i];
l.size = l.capacity = sizes[i];
l.offset = o;
o += l.size * (sizeof(idx_t) + ails->code_size);
if (ails->with_norm) {
o += l.size * (sizeof(idx_t) + ails->code_size + sizeof(float));
} else {
o += l.size * (sizeof(idx_t) + ails->code_size);
}
}
// resume normal reading of file
fseek(fdesc, o, SEEK_SET);
Expand Down
3 changes: 3 additions & 0 deletions thirdparty/faiss/faiss/invlists/OnDiskInvertedLists.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct OnDiskInvertedLists : InvertedLists {
size_t totsize;
uint8_t* ptr; // mmap base pointer
bool read_only; /// are inverted lists mapped read-only
bool with_norm = false;

OnDiskInvertedLists(size_t nlist, size_t code_size, const char* filename);

Expand Down Expand Up @@ -138,6 +139,8 @@ struct OnDiskInvertedLists : InvertedLists {

// empty constructor for the I/O functions
OnDiskInvertedLists();

const float* get_code_norms(size_t list_no, size_t offset) const override;
};

struct OnDiskInvertedListsIOHook : InvertedListsIOHook {
Expand Down

0 comments on commit 81d3b38

Please sign in to comment.