Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fail to load ivf_flat(metric = cosine) with DeserializeFromFile #841

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 69 additions & 36 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
namespace {
// Recall that approximate KNN search results must exceed in this test file.
constexpr float kKnnRecallThreshold{0.6f};
// NOTE(review): presumably the recall floor for brute-force search checks;
// its use is not visible in this part of the file -- confirm.
constexpr float kBruteForceRecallThreshold{0.95f};
// Scratch file the serialized index is written to before it is loaded back
// through DeserializeFromFile; the test removes it afterwards.
constexpr auto kMmapIndexPath = "/tmp/knowhere_dense_mmap_index_test";
}  // namespace

TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
Expand Down Expand Up @@ -166,49 +167,81 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen),
}));
auto idx_expected = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
if (name == knowhere::IndexEnum::INDEX_FAISS_SCANN) {
// need to check cpu model for scann
if (!faiss::support_pq_fast_scan) {
REQUIRE(idx_expected.error() == knowhere::Status::invalid_index_error);
return;
knowhere::BinarySet bs;
// build process
{
auto idx_expected = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
if (name == knowhere::IndexEnum::INDEX_FAISS_SCANN) {
// need to check cpu model for scann
if (!faiss::support_pq_fast_scan) {
REQUIRE(idx_expected.error() == knowhere::Status::invalid_index_error);
return;
}
}
auto idx = idx_expected.value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);
REQUIRE(idx.Count() == nb);

REQUIRE(idx.Serialize(bs) == knowhere::Status::success);
}
auto idx = idx_expected.value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);
REQUIRE(idx.Count() == nb);
// search process
auto load_with_mmap = GENERATE(as<bool>{}, true, false);
{
auto idx_expected = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto idx = idx_expected.value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
// TODO: qianya(DeserializeFromFile need raw data path. Next pr will remove raw data in ivf sq cc index, and
// use a knowhere struct to maintain raw data)
if (load_with_mmap && knowhere::KnowhereCheck::SupportMmapIndexTypeCheck(name) &&
name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC) {
auto binary = bs.GetByName(idx.Type());
auto data = binary->data.get();
auto size = binary->size;
std::remove(kMmapIndexPath);
std::ofstream out(kMmapIndexPath, std::ios::binary);
out.write((const char*)data, size);
out.close();
json["enable_mmap"] = true;
REQUIRE(idx.DeserializeFromFile(kMmapIndexPath, json) == knowhere::Status::success);
} else {
REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success);
}

knowhere::BinarySet bs;
REQUIRE(idx.Serialize(bs) == knowhere::Status::success);
REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success);
// TODO: qianya (IVFSQ_CC deserialize casted from the IVFSQ directly, which will cause the hasRawData reference
// to an uncertain address)
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC) {
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
}
// TODO: qianya (IVFSQ_CC deserialize casted from the IVFSQ directly, which will cause the hasRawData
// reference to an uncertain address)
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC) {
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
}

auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
float recall = GetKNNRecall(*gt.value(), *results.value());
bool scann_without_raw_data =
(name == knowhere::IndexEnum::INDEX_FAISS_SCANN && scann_gen2().dump() == cfg_json);
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && !scann_without_raw_data) {
REQUIRE(recall > kKnnRecallThreshold);
}
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
float recall = GetKNNRecall(*gt.value(), *results.value());
bool scann_without_raw_data =
(name == knowhere::IndexEnum::INDEX_FAISS_SCANN && scann_gen2().dump() == cfg_json);
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && !scann_without_raw_data) {
REQUIRE(recall > kKnnRecallThreshold);
}

if (metric == knowhere::metric::COSINE) {
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ &&
name != knowhere::IndexEnum::INDEX_HNSW_SQ8 && name != knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE &&
name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC && !scann_without_raw_data) {
REQUIRE(CheckDistanceInScope(*results.value(), topk, -1.00001, 1.00001));
if (metric == knowhere::metric::COSINE) {
if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ &&
name != knowhere::IndexEnum::INDEX_HNSW_SQ8 && name != knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE &&
name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC && !scann_without_raw_data) {
REQUIRE(CheckDistanceInScope(*results.value(), topk, -1.00001, 1.00001));
}
}
}
if (load_with_mmap && knowhere::KnowhereCheck::SupportMmapIndexTypeCheck(name) &&
name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC) {
std::remove(kMmapIndexPath);
}
}

SECTION("Test Range Search") {
Expand Down
29 changes: 26 additions & 3 deletions thirdparty/faiss/faiss/invlists/OnDiskInvertedLists.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,25 @@ void OnDiskInvertedLists::set_all_lists_sizes(const size_t* sizes) {
for (size_t i = 0; i < nlist; i++) {
lists[i].offset = ofs;
lists[i].capacity = lists[i].size = sizes[i];
ofs += sizes[i] * (sizeof(idx_t) + code_size);
if (this->with_norm) {
ofs += sizes[i] * (sizeof(idx_t) + code_size + sizeof(float));
} else {
ofs += sizes[i] * (sizeof(idx_t) + code_size);
}
}
}

/// Locate the float norm array stored for inverted list `list_no`.
///
/// @param list_no  index into `lists` of the inverted list to look up
/// @param offset   unused by this implementation
/// @return pointer into the mapped region at the start of the list's norms,
///         or nullptr when the lists carry no norms (`with_norm` is false)
///         or the list has no valid offset.
const float* OnDiskInvertedLists::get_code_norms(
        size_t list_no,
        size_t /*offset*/) const {
    // No norms were laid out for this set of inverted lists.
    if (!with_norm) {
        return nullptr;
    }

    const auto& entry = lists[list_no];
    if (entry.offset == INVALID_OFFSET) {
        return nullptr;
    }

    // The norms start once `capacity` entries of (code_size + sizeof(idx_t))
    // bytes each have been skipped from the beginning of the list.
    const uint8_t* list_base = ptr + entry.offset;
    return reinterpret_cast<const float*>(
            list_base + (code_size + sizeof(idx_t)) * entry.capacity);
}

Expand Down Expand Up @@ -761,14 +779,15 @@ InvertedLists* OnDiskInvertedListsIOHook::read(IOReader* f, int io_flags)
/** read from a ArrayInvertedLists into this invertedlist type */
InvertedLists* OnDiskInvertedListsIOHook::read_ArrayInvertedLists(
IOReader* f,
int /* io_flags */,
int io_flags,
size_t nlist,
size_t code_size,
const std::vector<size_t>& sizes) const {
auto ails = new OnDiskInvertedLists();
ails->nlist = nlist;
ails->code_size = code_size;
ails->read_only = true;
ails->with_norm = io_flags & IO_FLAG_WITH_NORM;
ails->lists.resize(nlist);

FileIOReader* reader = dynamic_cast<FileIOReader*>(f);
Expand Down Expand Up @@ -798,7 +817,11 @@ InvertedLists* OnDiskInvertedListsIOHook::read_ArrayInvertedLists(
OnDiskInvertedLists::List& l = ails->lists[i];
l.size = l.capacity = sizes[i];
l.offset = o;
o += l.size * (sizeof(idx_t) + ails->code_size);
if (ails->with_norm) {
o += l.size * (sizeof(idx_t) + ails->code_size + sizeof(float));
} else {
o += l.size * (sizeof(idx_t) + ails->code_size);
}
}
// resume normal reading of file
fseek(fdesc, o, SEEK_SET);
Expand Down
3 changes: 3 additions & 0 deletions thirdparty/faiss/faiss/invlists/OnDiskInvertedLists.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct OnDiskInvertedLists : InvertedLists {
size_t totsize;
uint8_t* ptr; // mmap base pointer
bool read_only; /// are inverted lists mapped read-only
bool with_norm = false;

OnDiskInvertedLists(size_t nlist, size_t code_size, const char* filename);

Expand Down Expand Up @@ -138,6 +139,8 @@ struct OnDiskInvertedLists : InvertedLists {

// empty constructor for the I/O functions
OnDiskInvertedLists();

const float* get_code_norms(size_t list_no, size_t offset) const override;
};

struct OnDiskInvertedListsIOHook : InvertedListsIOHook {
Expand Down
Loading