From b6f01a6665bf25a2abb95b03b0ca2032603c2749 Mon Sep 17 00:00:00 2001 From: Patrick Weizhi Xu Date: Mon, 26 Jun 2023 19:44:40 +0800 Subject: [PATCH] Fix GetVectorById Dim, Return Dataset and Set All HasRawData to False Except HNSW (#957) Signed-off-by: Patrick Weizhi Xu --- knowhere/index/vector_index/IndexAnnoy.cpp | 3 ++- knowhere/index/vector_index/IndexAnnoy.h | 2 +- knowhere/index/vector_index/IndexBinaryIDMAP.cpp | 3 ++- knowhere/index/vector_index/IndexBinaryIDMAP.h | 2 +- knowhere/index/vector_index/IndexBinaryIVF.cpp | 3 ++- knowhere/index/vector_index/IndexBinaryIVF.h | 2 +- knowhere/index/vector_index/IndexHNSW.cpp | 3 ++- knowhere/index/vector_index/IndexIDMAP.cpp | 3 ++- knowhere/index/vector_index/IndexIDMAP.h | 2 +- knowhere/index/vector_index/IndexIVF.cpp | 3 ++- knowhere/index/vector_index/IndexIVF.h | 2 +- knowhere/index/vector_index/adapter/VectorAdapter.cpp | 4 +++- knowhere/index/vector_index/adapter/VectorAdapter.h | 3 +-- knowhere/index/vector_offset_index/IndexIVF_NM.cpp | 3 ++- knowhere/index/vector_offset_index/IndexIVF_NM.h | 2 +- unittest/test_annoy.cpp | 2 +- unittest/test_binaryidmap.cpp | 2 +- unittest/test_binaryivf.cpp | 2 +- unittest/test_idmap.cpp | 2 +- unittest/test_ivf_nm.cpp | 2 +- 20 files changed, 29 insertions(+), 21 deletions(-) diff --git a/knowhere/index/vector_index/IndexAnnoy.cpp b/knowhere/index/vector_index/IndexAnnoy.cpp index eea236e99..8e229d245 100644 --- a/knowhere/index/vector_index/IndexAnnoy.cpp +++ b/knowhere/index/vector_index/IndexAnnoy.cpp @@ -127,6 +127,7 @@ IndexAnnoy::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } GET_DATA_WITH_IDS(dataset_ptr) + auto dim = Dim(); float* p_x = nullptr; try { @@ -142,7 +143,7 @@ IndexAnnoy::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } KNOWHERE_THROW_MSG(e.what()); } - return GenResultDataset(p_x); + return GenResultDataset(rows, dim, p_x); } DatasetPtr diff --git a/knowhere/index/vector_index/IndexAnnoy.h b/knowhere/index/vector_index/IndexAnnoy.h index 81d79c53e..141d0067a 100644 --- a/knowhere/index/vector_index/IndexAnnoy.h +++ b/knowhere/index/vector_index/IndexAnnoy.h @@ -48,7 +48,7 @@ class IndexAnnoy : public VecIndex { bool HasRawData(const std::string& /*metric_type*/) const override { - return true; + return false; } DatasetPtr diff --git a/knowhere/index/vector_index/IndexBinaryIDMAP.cpp b/knowhere/index/vector_index/IndexBinaryIDMAP.cpp index 9dabf1a55..437da0af9 100644 --- a/knowhere/index/vector_index/IndexBinaryIDMAP.cpp +++ b/knowhere/index/vector_index/IndexBinaryIDMAP.cpp @@ -43,6 +43,7 @@ BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) } GET_DATA_WITH_IDS(dataset_ptr) + auto dim = Dim(); uint8_t* p_x = nullptr; auto release_when_exception = [&]() { @@ -59,7 +60,7 @@ BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) KNOWHERE_THROW_IF_NOT_FMT(id >= 0 && id < bin_idmap_index->ntotal, "invalid id %ld", id); bin_idmap_index->reconstruct(id, p_x + i * dim / 8); } - return GenResultDataset(p_x); + return GenResultDataset(rows, dim, p_x); } catch (faiss::FaissException& e) { release_when_exception(); KNOWHERE_THROW_MSG(e.what()); diff --git a/knowhere/index/vector_index/IndexBinaryIDMAP.h b/knowhere/index/vector_index/IndexBinaryIDMAP.h index bd4835405..bafb7d9df 100644 --- a/knowhere/index/vector_index/IndexBinaryIDMAP.h +++ b/knowhere/index/vector_index/IndexBinaryIDMAP.h @@ -48,7 +48,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { bool HasRawData(const std::string& /*metric_type*/) const override { - return true; + return false; } DatasetPtr diff --git a/knowhere/index/vector_index/IndexBinaryIVF.cpp b/knowhere/index/vector_index/IndexBinaryIVF.cpp index 6fa054fa7..c623e948f 100644 --- a/knowhere/index/vector_index/IndexBinaryIVF.cpp +++ b/knowhere/index/vector_index/IndexBinaryIVF.cpp @@ -54,6 +54,7 @@ BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } GET_DATA_WITH_IDS(dataset_ptr) + auto dim = Dim(); uint8_t* p_x = nullptr; auto release_when_exception = [&]() { @@ -71,7 +72,7 @@ BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { KNOWHERE_THROW_IF_NOT_FMT(id >= 0 && id < bin_ivf_index->ntotal, "invalid id %ld", id); bin_ivf_index->reconstruct(id, p_x + i * dim / 8); } - return GenResultDataset(p_x); + return GenResultDataset(rows, dim, p_x); } catch (faiss::FaissException& e) { release_when_exception(); KNOWHERE_THROW_MSG(e.what()); diff --git a/knowhere/index/vector_index/IndexBinaryIVF.h b/knowhere/index/vector_index/IndexBinaryIVF.h index 2cb091eaa..779bfeb85 100644 --- a/knowhere/index/vector_index/IndexBinaryIVF.h +++ b/knowhere/index/vector_index/IndexBinaryIVF.h @@ -53,7 +53,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { bool HasRawData(const std::string& /*metric_type*/) const override { - return true; + return false; } DatasetPtr diff --git a/knowhere/index/vector_index/IndexHNSW.cpp b/knowhere/index/vector_index/IndexHNSW.cpp index 760fde183..71a2a2f46 100644 --- a/knowhere/index/vector_index/IndexHNSW.cpp +++ b/knowhere/index/vector_index/IndexHNSW.cpp @@ -123,6 +123,7 @@ IndexHNSW::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } GET_DATA_WITH_IDS(dataset_ptr) + auto dim = Dim(); float* p_x = nullptr; try { @@ -138,7 +139,7 @@ IndexHNSW::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } KNOWHERE_THROW_MSG(e.what()); } - return GenResultDataset(p_x); + return GenResultDataset(rows, dim, p_x); } DatasetPtr diff --git a/knowhere/index/vector_index/IndexIDMAP.cpp b/knowhere/index/vector_index/IndexIDMAP.cpp index 66114ae5e..216b4e5b3 100644 --- a/knowhere/index/vector_index/IndexIDMAP.cpp +++ b/knowhere/index/vector_index/IndexIDMAP.cpp @@ -76,6 +76,7 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } GET_DATA_WITH_IDS(dataset_ptr) + auto dim = Dim(); float* p_x = nullptr; auto release_when_exception = [&]() { @@ -92,7 +93,7 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { KNOWHERE_THROW_IF_NOT_FMT(id >= 0 && id < idmap_index->ntotal, "invalid id %ld", id); idmap_index->reconstruct(id, p_x + i * dim); } - return GenResultDataset(p_x); + return GenResultDataset(rows, dim, p_x); } catch (faiss::FaissException& e) { release_when_exception(); KNOWHERE_THROW_MSG(e.what()); diff --git a/knowhere/index/vector_index/IndexIDMAP.h b/knowhere/index/vector_index/IndexIDMAP.h index 73d311f17..b3041f13c 100644 --- a/knowhere/index/vector_index/IndexIDMAP.h +++ b/knowhere/index/vector_index/IndexIDMAP.h @@ -48,7 +48,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex { bool HasRawData(const std::string& /*metric_type*/) const override { - return true; + return false; } DatasetPtr diff --git a/knowhere/index/vector_index/IndexIVF.cpp b/knowhere/index/vector_index/IndexIVF.cpp index 16de8b420..9827667ae 100644 --- a/knowhere/index/vector_index/IndexIVF.cpp +++ b/knowhere/index/vector_index/IndexIVF.cpp @@ -98,6 +98,7 @@ IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } GET_DATA_WITH_IDS(dataset_ptr) + auto dim = Dim(); float* p_x = nullptr; auto release_when_exception = [&]() { @@ -115,7 +116,7 @@ IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { KNOWHERE_THROW_IF_NOT_FMT(id >= 0 && id < ivf_index->ntotal, "invalid id %ld", id); ivf_index->reconstruct(id, p_x + i * dim); } - return GenResultDataset(p_x); + return GenResultDataset(rows, dim, p_x); } catch (faiss::FaissException& e) { release_when_exception(); KNOWHERE_THROW_MSG(e.what()); diff --git a/knowhere/index/vector_index/IndexIVF.h b/knowhere/index/vector_index/IndexIVF.h index 4a4dc5506..fee223638 100644 --- a/knowhere/index/vector_index/IndexIVF.h +++ b/knowhere/index/vector_index/IndexIVF.h @@ -53,7 +53,7 @@ class IVF : public VecIndex, public FaissBaseIndex { bool HasRawData(const std::string& /*metric_type*/) const override { - return true; + return false; } DatasetPtr diff --git a/knowhere/index/vector_index/adapter/VectorAdapter.cpp b/knowhere/index/vector_index/adapter/VectorAdapter.cpp index 3ca4acb71..30fe7d854 100644 --- a/knowhere/index/vector_index/adapter/VectorAdapter.cpp +++ b/knowhere/index/vector_index/adapter/VectorAdapter.cpp @@ -36,8 +36,10 @@ GenDatasetWithIds(const int64_t n, const int64_t dim, const int64_t* ids) { } DatasetPtr -GenResultDataset(const void* tensor) { +GenResultDataset(const int64_t rows, const int64_t dim, const void* tensor) { auto ret_ds = std::make_shared(); + SetDatasetRows(ret_ds, rows); + SetDatasetDim(ret_ds, dim); SetDatasetOutputTensor(ret_ds, tensor); return ret_ds; } diff --git a/knowhere/index/vector_index/adapter/VectorAdapter.h b/knowhere/index/vector_index/adapter/VectorAdapter.h index d0da8ed78..c7b389d50 100644 --- a/knowhere/index/vector_index/adapter/VectorAdapter.h +++ b/knowhere/index/vector_index/adapter/VectorAdapter.h @@ -64,7 +64,6 @@ DEFINE_DATASET_SETTER(SetDatasetJsonIdSet, meta::JSON_ID_SET, const std::string) #define GET_DATA_WITH_IDS(ds_ptr) \ auto rows = knowhere::GetDatasetRows(ds_ptr); \ - auto dim = knowhere::GetDatasetDim(ds_ptr); \ auto p_ids = knowhere::GetDatasetInputIDs(ds_ptr); #define GET_TENSOR_DATA(ds_ptr) \ @@ -82,7 +81,7 @@ extern DatasetPtr GenDatasetWithIds(const int64_t n, const int64_t dim, const int64_t* ids); extern DatasetPtr -GenResultDataset(const void* tensor); +GenResultDataset(const int64_t rows, const int64_t dim, const void* tensor); extern DatasetPtr GenResultDataset(const int64_t* ids, const float* distance); diff --git a/knowhere/index/vector_offset_index/IndexIVF_NM.cpp b/knowhere/index/vector_offset_index/IndexIVF_NM.cpp index 7e24dbe9a..4bc2d3723 100644 --- a/knowhere/index/vector_offset_index/IndexIVF_NM.cpp +++ b/knowhere/index/vector_offset_index/IndexIVF_NM.cpp @@ -153,6 +153,7 @@ IVF_NM::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } GET_DATA_WITH_IDS(dataset_ptr) + auto dim = Dim(); float* p_x = nullptr; auto release_when_exception = [&]() { @@ -170,7 +171,7 @@ IVF_NM::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { KNOWHERE_THROW_IF_NOT_FMT(id >= 0 && id < ivf_index->ntotal, "invalid id %ld", id); ivf_index->reconstruct_without_codes(id, p_x + i * dim); } - return GenResultDataset(p_x); + return GenResultDataset(rows, dim, p_x); } catch (faiss::FaissException& e) { release_when_exception(); KNOWHERE_THROW_MSG(e.what()); diff --git a/knowhere/index/vector_offset_index/IndexIVF_NM.h b/knowhere/index/vector_offset_index/IndexIVF_NM.h index 3a1b839d9..d088892b7 100644 --- a/knowhere/index/vector_offset_index/IndexIVF_NM.h +++ b/knowhere/index/vector_offset_index/IndexIVF_NM.h @@ -54,7 +54,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex { bool HasRawData(const std::string& /*metric_type*/) const override { - return true; + return false; } DatasetPtr diff --git a/unittest/test_annoy.cpp b/unittest/test_annoy.cpp index c6ae257dc..10fc3ca53 100644 --- a/unittest/test_annoy.cpp +++ b/unittest/test_annoy.cpp @@ -57,7 +57,7 @@ TEST_P(AnnoyTest, annoy_basic) { ASSERT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); - ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); + ASSERT_FALSE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_binaryidmap.cpp b/unittest/test_binaryidmap.cpp index f280462ba..4b06fa84e 100644 --- a/unittest/test_binaryidmap.cpp +++ b/unittest/test_binaryidmap.cpp @@ -71,7 +71,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); - ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); + ASSERT_FALSE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertBinVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_binaryivf.cpp b/unittest/test_binaryivf.cpp index 0443a8b59..d9437cea2 100644 --- a/unittest/test_binaryivf.cpp +++ b/unittest/test_binaryivf.cpp @@ -68,7 +68,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); - ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); + ASSERT_FALSE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertBinVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_idmap.cpp b/unittest/test_idmap.cpp index 5eb401fc9..af8035fcb 100644 --- a/unittest/test_idmap.cpp +++ b/unittest/test_idmap.cpp @@ -90,7 +90,7 @@ TEST_P(IDMAPTest, idmap_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); - ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); + ASSERT_FALSE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_ivf_nm.cpp b/unittest/test_ivf_nm.cpp index 94b89fc07..8ebd6c276 100644 --- a/unittest/test_ivf_nm.cpp +++ b/unittest/test_ivf_nm.cpp @@ -105,7 +105,7 @@ TEST_P(IVFNMTest, ivfnm_basic) { LoadRawData(index_, base_dataset, conf_); if (index_mode_ == knowhere::IndexMode::MODE_CPU) { - ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); + ASSERT_FALSE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertVec(result, base_dataset, id_dataset, nq, dim);