diff --git a/include/knowhere/config.h b/include/knowhere/config.h index 3bb3088da..7596be4cc 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -622,14 +622,22 @@ class Config { const float defaultRangeFilter = 1.0f / 0.0; +template +knowhere::Status +CheckConfig(const std::string& index_type, const int32_t& version, knowhere::Json& json, + knowhere::PARAM_TYPE param_type, std::string& msg); + class BaseConfig : public Config { public: + CFG_INT dim; // just used for config verify CFG_STRING metric_type; CFG_INT k; CFG_INT num_build_thread; CFG_BOOL retrieve_friendly; CFG_STRING data_path; CFG_STRING index_prefix; + + CFG_FLOAT vec_field_size_gb; // for distance metrics, we search for vectors with distance in [range_filter, radius). // for similarity metrics, we search for vectors with similarity in (radius, range_filter]. CFG_FLOAT radius; @@ -659,6 +667,10 @@ class BaseConfig : public Config { CFG_FLOAT bm25_b; CFG_FLOAT bm25_avgdl; KNOHWERE_DECLARE_CONFIG(BaseConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(dim) + .allow_empty_without_default() + .description("vector dim") + .for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(metric_type) .set_default("L2") .description("metric type") @@ -679,6 +691,10 @@ class BaseConfig : public Config { .allow_empty_without_default() .for_train() .for_deserialize(); + KNOWHERE_CONFIG_DECLARE_FIELD(vec_field_size_gb) + .description("vector filed size in GB.") + .set_default(0) + .for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(k) .set_default(10) .description("search for top k similar vector.") diff --git a/include/knowhere/feature.h b/include/knowhere/feature.h new file mode 100644 index 000000000..2348adb2d --- /dev/null +++ b/include/knowhere/feature.h @@ -0,0 +1,52 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#ifndef FEATURE_H +#define FEATURE_H + +// these features have been report to outside (milvus); pls sync the feature code when it needs to be changed. +namespace knowhere::feature { +// vector datatype support : binary +constexpr uint64_t BINARY = 1UL << 0; +// vector datatype support : float32 +constexpr uint64_t FLOAT32 = 1UL << 1; +// vector datatype support : fp16 +constexpr uint64_t FP16 = 1UL << 2; +// vector datatype support : bf16 +constexpr uint64_t BF16 = 1UL << 3; +// vector datatype support : sparse_float32 +constexpr uint64_t SPARSE_FLOAT32 = 1UL << 4; + +// This flag indicates that there is no need to create any index structure (build stage can be skipped) +constexpr uint64_t BF = 1UL << 16; +// This flag indicates that the index defaults to KNN search, meaning the recall rate is 100% +constexpr uint64_t KNN = 1UL << 17; +// This flag indicates that the index is deployed on GPU (need GPU devices) +constexpr uint64_t GPU = 1UL << 18; +// This flag indicates that the index support using mmap manage its mainly memory, which can significant improve the +// capacity +constexpr uint64_t MMAP = 1UL << 19; +// This flag indicates that the index support using materialized view to accelerate filtering search +constexpr uint64_t MV = 1UL << 20; +// This flag indicates that the index need disk during search +constexpr uint64_t DISK = 1UL << 21; + +constexpr uint64_t ALL_TYPE = BINARY | FLOAT32 | FP16 | BF16 | SPARSE_FLOAT32; +constexpr uint64_t ALL_DENSE_TYPE = BINARY | FLOAT32 | FP16 | BF16; +constexpr uint64_t ALL_DENSE_FLOAT_TYPE = FLOAT32 | FP16 | BF16; + +constexpr uint64_t GPU_KNN_FLOAT_INDEX = FLOAT32 | GPU | KNN; +constexpr uint64_t GPU_ANN_FLOAT_INDEX = FLOAT32 | GPU; +} // namespace knowhere::feature +#endif /* FEATURE_H */ diff --git a/include/knowhere/index/index.h b/include/knowhere/index/index.h index 9c3992f07..679a934a8 100644 --- a/include/knowhere/index/index.h +++ b/include/knowhere/index/index.h @@ -193,8 +193,10 @@ class Index { if (node == nullptr) return; node->DecRef(); - if (!node->Ref()) + if (!node->Ref()) { delete node; + node = nullptr; + } } private: diff --git a/include/knowhere/index/index_factory.h b/include/knowhere/index/index_factory.h index 501904f69..185953f7e 100644 --- a/include/knowhere/index/index_factory.h +++ b/include/knowhere/index/index_factory.h @@ -25,13 +25,18 @@ class IndexFactory { public: template expected> - Create(const std::string& name, const int32_t& version, const Object& object = nullptr); + Create(const std::string& name, const int32_t& version, const Object& object = nullptr, bool runtimeCheck = true); template const IndexFactory& - Register(const std::string& name, std::function(const int32_t&, const Object&)> func); + Register(const std::string& name, std::function(const int32_t&, const Object&)> func, + const uint64_t features); static IndexFactory& Instance(); typedef std::tuple>, std::set> GlobalIndexTable; + bool + FeatureCheck(const std::string& name, uint64_t feature) const; + static const std::map& + GetIndexFeatures(); static GlobalIndexTable& StaticIndexTableInstance(); @@ -47,36 +52,88 @@ class IndexFactory { std::function fun_value; }; typedef std::map> FuncMap; + typedef std::map FeatureMap; IndexFactory(); static FuncMap& MapInstance(); + static FeatureMap& + FeatureMapInstance(); }; #define KNOWHERE_CONCAT(x, y) index_factory_ref_##x##y -#define KNOWHERE_REGISTER_GLOBAL(name, func, data_type) \ - const IndexFactory& KNOWHERE_CONCAT(name, data_type) = IndexFactory::Instance().Register(#name, func) -#define KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, data_type, ...) \ +#define KNOWHERE_REGISTER_GLOBAL(name, func, data_type, condition, features) \ + const IndexFactory& KNOWHERE_CONCAT(name, data_type) = \ + condition ? IndexFactory::Instance().Register(#name, func, features) : IndexFactory::Instance(); + +#define KNOWHERE_REGISTER_FUNC_GLOBAL(name, func, data_type, features) \ + KNOWHERE_REGISTER_GLOBAL(name, func, data_type, typeCheck(features), features) + +#define KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, data_type, features, ...) \ KNOWHERE_REGISTER_GLOBAL( \ name, \ (static_cast> (*)(const int32_t&, const Object&)>( \ &Index>::Create)), \ - data_type) -#define KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, data_type, ...) \ + data_type, typeCheck(features), features) + +#define KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, data_type, features, ...) \ KNOWHERE_REGISTER_GLOBAL( \ name, \ [](const int32_t& version, const Object& object) { \ return (Index>::Create( \ std::make_unique::type, ##__VA_ARGS__>>(version, object))); \ }, \ - data_type) -#define KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(name, index_node, data_type, thread_size) \ - KNOWHERE_REGISTER_GLOBAL( \ + data_type, typeCheck(features), features) + +#define KNOWHERE_SIMPLE_REGISTER_ALL_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); + +#define KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::SPARSE_FLOAT32), \ + ##__VA_ARGS__); + +#define KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::ALL_DENSE_TYPE), \ + ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_TYPE), \ + ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_TYPE), \ + ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_TYPE), \ + ##__VA_ARGS__); + +#define KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__); + +#define KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT32_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__); + +#define KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ + ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ + ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ + ##__VA_ARGS__); + +#define KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ + ##__VA_ARGS__); \ + KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ + ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ + ##__VA_ARGS__); + +#define KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(name, index_node, data_type, features, thread_size) \ + KNOWHERE_REGISTER_FUNC_GLOBAL( \ name, \ [](const int32_t& version, const Object& object) { \ return (Index::Create( \ std::make_unique::type>>(version, object), thread_size)); \ }, \ - data_type) + data_type, features) #define KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(table_index, name, index_table) \ static int name = []() -> int { \ auto& static_index_table = std::get(IndexFactory::StaticIndexTableInstance()); \ diff --git a/include/knowhere/index/index_node.h b/include/knowhere/index/index_node.h index f3364b6f0..5c82ce723 100644 --- a/include/knowhere/index/index_node.h +++ b/include/knowhere/index/index_node.h @@ -241,6 +241,11 @@ class IndexNode : public Object { virtual bool HasRawData(const std::string& metric_type) const = 0; + virtual Status + ConfigCheck(const Config& cfg, PARAM_TYPE paramType, std::string& msg) const { + return knowhere::Status::success; + } + virtual bool IsAdditionalScalarSupported() const { return false; diff --git a/include/knowhere/operands.h b/include/knowhere/operands.h index 6db211ef7..0db13a3df 100644 --- a/include/knowhere/operands.h +++ b/include/knowhere/operands.h @@ -19,6 +19,8 @@ #include #include +#include "feature.h" + namespace { union fp32_bits { uint32_t as_bits; @@ -143,6 +145,25 @@ struct bf16 { } }; +template +bool +typeCheck(uint64_t features) { + if constexpr (std::is_same_v) { + return features & knowhere::feature::BINARY; + } + if constexpr (std::is_same_v) { + return features & knowhere::feature::FP16; + } + if constexpr (std::is_same_v) { + return features & knowhere::feature::BF16; + } + // TODO : add sparse_fp32 data type + if constexpr (std::is_same_v) { + return (features & knowhere::feature::FLOAT32) || (features & knowhere::feature::SPARSE_FLOAT32); + } + return false; +} + template using TypeMatch = std::bool_constant<(... | std::is_same_v)>; template diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index 866942279..3bf6a62cf 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -14,6 +14,7 @@ #include #include +#include #include #include "knowhere/binaryset.h" @@ -186,6 +187,9 @@ ConvertIVFFlat(const BinarySet& binset, const MetricType metric_type, const uint bool UseDiskLoad(const std::string& index_type, const int32_t& /*version*/); +bool +ParamCheck(const std::string& index_type, const std::map& config); + template static void writeBinaryPOD(W& out, const T& podRef) { diff --git a/src/common/config.cc b/src/common/config.cc index 81d3b80f9..ca275a009 100644 --- a/src/common/config.cc +++ b/src/common/config.cc @@ -20,6 +20,7 @@ #include "index/hnsw/hnsw_config.h" #include "index/ivf/ivf_config.h" #include "index/sparse/sparse_inverted_index_config.h" +#include "knowhere/index/index_factory.h" #include "knowhere/log.h" namespace knowhere { @@ -119,54 +120,34 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg return Status::success; } -} // namespace knowhere - -extern "C" __attribute__((visibility("default"))) int -CheckConfig(int index_type, char const* str, int n, int param_type); - -int -CheckConfig(int index_type, const char* str, int n, int param_type) { - if (!str || n <= 0) { - return int(knowhere::Status::invalid_args); - } - knowhere::Json json = knowhere::Json::parse(str, str + n); - std::unique_ptr cfg; - - switch (index_type) { - case 0: - cfg = std::make_unique(); - break; - case 1: - cfg = std::make_unique(); - break; - case 2: - cfg = std::make_unique(); - break; - case 3: - cfg = std::make_unique(); - break; - case 4: - cfg = std::make_unique(); - break; - case 5: - cfg = std::make_unique(); - break; - case 6: - cfg = std::make_unique(); - break; - case 7: - cfg = std::make_unique(); - break; - case 8: - cfg = std::make_unique(); - break; - default: - return int(knowhere::Status::invalid_args); +template +knowhere::Status +CheckConfig(const std::string& index_type, const int32_t& version, knowhere::Json& json, + knowhere::PARAM_TYPE param_type, std::string& msg) { + auto index = knowhere::IndexFactory::Instance().Create(index_type, version, nullptr, false); + if (!index.has_value()) { + msg = index.what(); + return index.error(); } - - auto res = knowhere::Config::FormatAndCheck(*cfg, json, nullptr); + auto cfg = index.value().Node()->CreateConfig(); + auto res = knowhere::Config::FormatAndCheck(*cfg, json, &msg); if (res != knowhere::Status::success) { - return int(res); + return res; } - return int(knowhere::Config::Load(*cfg, json, knowhere::PARAM_TYPE(param_type), nullptr)); + return knowhere::Config::Load(*cfg, json, knowhere::PARAM_TYPE(param_type), &msg); } + +template knowhere::Status +CheckConfig(const std::string& index_type, const int32_t& version, knowhere::Json& json, + knowhere::PARAM_TYPE param_type, std::string& msg); +template knowhere::Status +CheckConfig(const std::string& index_type, const int32_t& version, knowhere::Json& json, + knowhere::PARAM_TYPE param_type, std::string& msg); +template knowhere::Status +CheckConfig(const std::string& index_type, const int32_t& version, knowhere::Json& json, + knowhere::PARAM_TYPE param_type, std::string& msg); +template knowhere::Status +CheckConfig(const std::string& index_type, const int32_t& version, knowhere::Json& json, + knowhere::PARAM_TYPE param_type, std::string& msg); + +} // namespace knowhere diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index 11be5d8e0..66e2876f7 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -22,6 +22,7 @@ #include "knowhere/comp/thread_pool.h" #include "knowhere/dataset.h" #include "knowhere/expected.h" +#include "knowhere/feature.h" #include "knowhere/file_manager.h" #include "knowhere/index/index_factory.h" #include "knowhere/log.h" @@ -38,10 +39,12 @@ class DiskANNIndexNode : public IndexNode { public: using DistType = float; DiskANNIndexNode(const int32_t& version, const Object& object) : is_prepared_(false), dim_(-1), count_(-1) { - assert(typeid(object) == typeid(Pack>)); - auto diskann_index_pack = dynamic_cast>*>(&object); - assert(diskann_index_pack != nullptr); - file_manager_ = diskann_index_pack->GetPack(); + if (typeid(object) == typeid(Pack>)) { + auto disk_file_pack = dynamic_cast>*>(&object); + if (disk_file_pack != nullptr) { + file_manager_ = disk_file_pack->GetPack(); + } + } } Status @@ -364,6 +367,8 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& c } }(); + assert(file_manager_ != nullptr); + // Load file from file manager. for (auto& filename : GetNecessaryFilenames( index_prefix_, need_norm, prep_conf.search_cache_budget_gb.value() > 0 && !prep_conf.use_bfs_cache.value(), @@ -689,12 +694,8 @@ DiskANNIndexNode::GetCachedNodeNum(const float cache_dram_budget, cons } #ifdef KNOWHERE_WITH_CARDINAL -KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN_DEPRECATED, DiskANNIndexNode, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN_DEPRECATED, DiskANNIndexNode, fp16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN_DEPRECATED, DiskANNIndexNode, bf16); +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(DISKANN_DEPRECATED, DiskANNIndexNode, knowhere::feature::DISK) #else -KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, bf16); +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(DISKANN, DiskANNIndexNode, knowhere::feature::DISK) #endif } // namespace knowhere diff --git a/src/index/diskann/diskann_config.h b/src/index/diskann/diskann_config.h index 660e1902d..6be63e30b 100644 --- a/src/index/diskann/diskann_config.h +++ b/src/index/diskann/diskann_config.h @@ -33,6 +33,8 @@ class DiskANNConfig : public BaseConfig { // complexity. Plz set this value larger than the max_degree unless you need to build indices really quickly and can // somewhat compromise on quality. CFG_INT search_list_size; + + CFG_FLOAT pq_code_budget_gb_ratio; // Limit the size of the PQ code after the raw vector has been PQ-encoded. PQ code is a (pq_code_budget_gb * 1024 * // 1024 * 1024) / row_num)-dimensional uint8 vector. If pq_code_budget_gb is too large, it will be adjusted to the // size of dim*row_num. @@ -50,6 +52,9 @@ class DiskANNConfig : public BaseConfig { // This is the flag to enable fast build, in which we will not build vamana graph by full 2 round. This can // accelerate index build ~30% with an ~1% recall regression. CFG_BOOL accelerate_build; + + CFG_FLOAT search_cache_budget_gb_ratio; + // While serving the index, the entire graph is stored on SSD. For faster search performance, you can cache a few // frequently accessed nodes in memory. CFG_FLOAT search_cache_budget_gb; @@ -86,12 +91,19 @@ class DiskANNConfig : public BaseConfig { .for_search() .for_range_search() .for_iterator(); + KNOWHERE_CONFIG_DECLARE_FIELD(pq_code_budget_gb_ratio) + .description("the size of PQ compared with vector field data") + .set_default(0) + .set_range(0, std::numeric_limits::max()) + .for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(pq_code_budget_gb) .description("the size of PQ compressed representation in GB.") + .set_default(0) .set_range(0, std::numeric_limits::max()) .for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(build_dram_budget_gb) .description("limit on the memory allowed for building the index in GB.") + .set_default(0) .set_range(0, std::numeric_limits::max()) .for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(disk_pq_dims) @@ -102,6 +114,12 @@ class DiskANNConfig : public BaseConfig { .description("a flag to enbale fast build.") .set_default(false) .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(search_cache_budget_gb_ratio) + .description("the size of cached nodes compared with vector field data") + .set_default(0) + .set_range(0, std::numeric_limits::max()) + .for_train() + .for_deserialize(); KNOWHERE_CONFIG_DECLARE_FIELD(search_cache_budget_gb) .description("the size of cached nodes in GB.") .set_default(0) @@ -148,6 +166,10 @@ class DiskANNConfig : public BaseConfig { if (!search_list_size.has_value()) { search_list_size = kDefaultSearchListSizeForBuild; } + pq_code_budget_gb = + std::max(pq_code_budget_gb.value(), pq_code_budget_gb_ratio.value() * vec_field_size_gb.value()); + search_cache_budget_gb = std::max(search_cache_budget_gb.value(), + search_cache_budget_gb_ratio.value() * vec_field_size_gb.value()); break; } case PARAM_TYPE::SEARCH: { diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index c65c458fd..e13e5b6d8 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -18,6 +18,7 @@ #include "io/memory_io.h" #include "knowhere/bitsetview_idselector.h" #include "knowhere/comp/thread_pool.h" +#include "knowhere/feature.h" #include "knowhere/index/index_factory.h" #include "knowhere/index/index_node_data_mock_wrapper.h" #include "knowhere/log.h" @@ -373,9 +374,14 @@ class FlatIndexNode : public IndexNode { std::shared_ptr search_pool_; }; -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FLAT, FlatIndexNode, fp32, faiss::IndexFlat); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(BINFLAT, FlatIndexNode, bin1, faiss::IndexBinaryFlat); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(BIN_FLAT, FlatIndexNode, bin1, faiss::IndexBinaryFlat); -KNOWHERE_MOCK_REGISTER_GLOBAL(FLAT, FlatIndexNode, fp16, faiss::IndexFlat); -KNOWHERE_MOCK_REGISTER_GLOBAL(FLAT, FlatIndexNode, bf16, faiss::IndexFlat); +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FLAT, FlatIndexNode, + knowhere::feature::BF | knowhere::feature::KNN | knowhere::feature::MMAP, + faiss::IndexFlat); + +KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(BINFLAT, FlatIndexNode, + knowhere::feature::BF | knowhere::feature::KNN | knowhere::feature::MMAP, + faiss::IndexBinaryFlat); +KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(BIN_FLAT, FlatIndexNode, + knowhere::feature::BF | knowhere::feature::KNN | knowhere::feature::MMAP, + faiss::IndexBinaryFlat); } // namespace knowhere diff --git a/src/index/gpu/flat_gpu/flat_gpu.cc b/src/index/gpu/flat_gpu/flat_gpu.cc index e01862a00..52df03ec5 100644 --- a/src/index/gpu/flat_gpu/flat_gpu.cc +++ b/src/index/gpu/flat_gpu/flat_gpu.cc @@ -189,6 +189,4 @@ class GpuFlatIndexNode : public IndexNode { mutable ResWPtr res_; std::unique_ptr index_; }; - -KNOWHERE_SIMPLE_REGISTER_GLOBAL(GPU_FAISS_FLAT, GpuFlatIndexNode, fp32); } // namespace knowhere diff --git a/src/index/gpu/ivf_gpu/ivf_gpu.cc b/src/index/gpu/ivf_gpu/ivf_gpu.cc index c974df2fa..0828c3b6c 100644 --- a/src/index/gpu/ivf_gpu/ivf_gpu.cc +++ b/src/index/gpu/ivf_gpu/ivf_gpu.cc @@ -272,8 +272,4 @@ class GpuIvfIndexNode : public IndexNode { mutable ResWPtr res_; std::unique_ptr index_; }; - -KNOWHERE_SIMPLE_REGISTER_GLOBAL(GPU_FAISS_IVF_FLAT, GpuIvfIndexNode, fp32, faiss::IndexIVFFlat); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(GPU_FAISS_IVF_PQ, GpuIvfIndexNode, fp32, faiss::IndexIVFPQ); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(GPU_FAISS_IVF_SQ8, GpuIvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft.h b/src/index/gpu_raft/gpu_raft.h index 3c529449d..5bf605a61 100644 --- a/src/index/gpu_raft/gpu_raft.h +++ b/src/index/gpu_raft/gpu_raft.h @@ -196,7 +196,7 @@ struct GpuRaftIndexNode : public IndexNode { } Status - DeserializeFromFile(const std::string& filename, const Config& config) { + DeserializeFromFile(const std::string& filename, const Config& config) override { return Status::not_implemented; } diff --git a/src/index/gpu_raft/gpu_raft_brute_force.cc b/src/index/gpu_raft/gpu_raft_brute_force.cc index 7e4200567..2626fe408 100644 --- a/src/index/gpu_raft/gpu_raft_brute_force.cc +++ b/src/index/gpu_raft/gpu_raft_brute_force.cc @@ -24,14 +24,16 @@ #include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, []() { - int count; - RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); - return count * cuda_concurrent_size_per_device; -}()); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, []() { - int count; - RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); - return count * cuda_concurrent_size_per_device; -}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, + knowhere::feature::GPU_KNN_FLOAT_INDEX, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * cuda_concurrent_size_per_device; + }()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, + knowhere::feature::GPU_KNN_FLOAT_INDEX, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * cuda_concurrent_size_per_device; + }()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_brute_force_config.h b/src/index/gpu_raft/gpu_raft_brute_force_config.h index ccc36f7a8..79b2b595a 100644 --- a/src/index/gpu_raft/gpu_raft_brute_force_config.h +++ b/src/index/gpu_raft/gpu_raft_brute_force_config.h @@ -23,7 +23,20 @@ namespace knowhere { -struct GpuRaftBruteForceConfig : public BaseConfig {}; +struct GpuRaftBruteForceConfig : public BaseConfig { + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + if (param_type == PARAM_TYPE::TRAIN) { + auto legal_metric_list = std::vector{"L2", "IP"}; + std::string metric = metric_type.value(); + if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { + *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + return Status::invalid_metric_type; + } + } + return Status::success; + } +}; [[nodiscard]] inline auto to_raft_knowhere_config(GpuRaftBruteForceConfig const& cfg) { diff --git a/src/index/gpu_raft/gpu_raft_cagra.cc b/src/index/gpu_raft/gpu_raft_cagra.cc index 240052eef..e5f05a60c 100644 --- a/src/index/gpu_raft/gpu_raft_cagra.cc +++ b/src/index/gpu_raft/gpu_raft_cagra.cc @@ -151,7 +151,7 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode { } Status - DeserializeFromFile(const std::string& filename, const Config& config) { + DeserializeFromFile(const std::string& filename, const Config& config) override { return Status::not_implemented; } @@ -160,14 +160,16 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode { std::unique_ptr> hnsw_index_ = nullptr; }; -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_CAGRA, GpuRaftCagraHybridIndexNode, fp32, []() { - int count; - RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); - return count * cuda_concurrent_size_per_device; -}()); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CAGRA, GpuRaftCagraHybridIndexNode, fp32, []() { - int count; - RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); - return count * cuda_concurrent_size_per_device; -}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_CAGRA, GpuRaftCagraHybridIndexNode, fp32, + knowhere::feature::FLOAT32 | knowhere::feature::GPU, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * cuda_concurrent_size_per_device; + }()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CAGRA, GpuRaftCagraHybridIndexNode, fp32, + knowhere::feature::FLOAT32 | knowhere::feature::GPU, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * cuda_concurrent_size_per_device; + }()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_cagra_config.h b/src/index/gpu_raft/gpu_raft_cagra_config.h index b9a17c193..ad7de43ae 100644 --- a/src/index/gpu_raft/gpu_raft_cagra_config.h +++ b/src/index/gpu_raft/gpu_raft_cagra_config.h @@ -71,7 +71,7 @@ struct GpuRaftCagraConfig : public BaseConfig { KNOWHERE_CONFIG_DECLARE_FIELD(max_queries).description("maximum batch size").set_default(0).for_search(); KNOWHERE_CONFIG_DECLARE_FIELD(build_algo) .description("algorithm used to build knn graph") - .set_default("IVF_PQ") + .set_default("NN_DESCENT") .for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(search_algo) .description("algorithm used for search") diff --git a/src/index/gpu_raft/gpu_raft_ivf_flat.cc b/src/index/gpu_raft/gpu_raft_ivf_flat.cc index 7b63cdb23..24c69e4b2 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_flat.cc +++ b/src/index/gpu_raft/gpu_raft_ivf_flat.cc @@ -23,14 +23,16 @@ #include "knowhere/index/index_node_thread_pool_wrapper.h" #include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, []() { - int count; - RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); - return count * cuda_concurrent_size_per_device; -}()); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, []() { - int count; - RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); - return count * cuda_concurrent_size_per_device; -}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, + knowhere::feature::GPU_ANN_FLOAT_INDEX, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * cuda_concurrent_size_per_device; + }()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, + knowhere::feature::GPU_ANN_FLOAT_INDEX, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * cuda_concurrent_size_per_device; + }()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_ivf_flat_config.h b/src/index/gpu_raft/gpu_raft_ivf_flat_config.h index 4d9eed752..80fa9e52e 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_flat_config.h +++ b/src/index/gpu_raft/gpu_raft_ivf_flat_config.h @@ -57,6 +57,19 @@ struct GpuRaftIvfFlatConfig : public IvfFlatConfig { .set_default(false) .for_train(); } + + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + if (param_type == PARAM_TYPE::TRAIN) { + auto legal_metric_list = std::vector{"L2", "IP"}; + std::string metric = metric_type.value(); + if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { + *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + return Status::invalid_metric_type; + } + } + return Status::success; + } }; [[nodiscard]] inline auto diff --git a/src/index/gpu_raft/gpu_raft_ivf_pq.cc b/src/index/gpu_raft/gpu_raft_ivf_pq.cc index fbed9337e..c1baf66b8 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_pq.cc +++ b/src/index/gpu_raft/gpu_raft_ivf_pq.cc @@ -25,6 +25,7 @@ namespace knowhere { KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, + knowhere::feature::GPU_ANN_FLOAT_INDEX, []() { int count; RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); @@ -33,6 +34,7 @@ KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_PQ, GpuRaftIvfPqIndexNode ); KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, + knowhere::feature::GPU_ANN_FLOAT_INDEX, []() { int count; RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); diff --git a/src/index/gpu_raft/gpu_raft_ivf_pq_config.h b/src/index/gpu_raft/gpu_raft_ivf_pq_config.h index 402a528e8..346e3721e 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_pq_config.h +++ b/src/index/gpu_raft/gpu_raft_ivf_pq_config.h @@ -92,6 +92,19 @@ struct GpuRaftIvfPqConfig : public IvfPqConfig { .set_default(1.0f) .for_search(); } + + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + if (param_type == PARAM_TYPE::TRAIN) { + auto legal_metric_list = std::vector{"L2", "IP"}; + std::string metric = metric_type.value(); + if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { + *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + return Status::invalid_metric_type; + } + } + return Status::success; + } }; [[nodiscard]] inline auto diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index fda8b2abf..6072aca6d 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -1461,20 +1461,13 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS }; // -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, fp16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, bf16); - -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, fp16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, bf16); - -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, fp16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, bf16); - -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, fp16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, bf16); +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, + knowhere::feature::MMAP) +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, + knowhere::feature::MMAP) +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, + knowhere::feature::MMAP) +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, + knowhere::feature::MMAP) } // namespace knowhere diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index 802721e4e..3aec22bda 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -22,6 +22,7 @@ #include "knowhere/comp/time_recorder.h" #include "knowhere/config.h" #include "knowhere/expected.h" +#include "knowhere/feature.h" #include "knowhere/index/index_factory.h" #include "knowhere/index/index_node_data_mock_wrapper.h" #include "knowhere/log.h" @@ -88,6 +89,31 @@ class HnswIndexNode : public IndexNode { return Status::success; } + virtual Status + ConfigCheck(const Config& cfg, PARAM_TYPE paramType, std::string& msg) const override { + auto hnsw_cfg = static_cast(cfg); + + if (paramType == PARAM_TYPE::TRAIN) { + if constexpr (KnowhereFloatTypeCheck::value) { + if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2) || + IsMetricType(hnsw_cfg.metric_type.value(), metric::IP) || + IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) { + } else { + msg = "metric type " + hnsw_cfg.metric_type.value() + " not found or not supported, supported: [L2 IP COSINE]"; + return Status::invalid_metric_type; + } + } else { + if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING) || + IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) { + } else { + msg = "metric type " + hnsw_cfg.metric_type.value() + " not found or not supported, supported: [HAMMING JACCARD]"; + return Status::invalid_metric_type; + } + } + } + return Status::success; + } + Status Add(const DataSetPtr dataset, const Config& cfg) override { if (!index_) { @@ -586,21 +612,12 @@ class HnswIndexNode : public IndexNode { }; #ifdef KNOWHERE_WITH_CARDINAL -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, fp16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, bf16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, bin1); +KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, knowhere::feature::MMAP) #else -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bin1); +KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(HNSW, HnswIndexNode, knowhere::feature::MMAP) #endif -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp32, QuantType::SQ8); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp32, QuantType::SQ8Refine); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine); +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ8, HnswIndexNode, knowhere::feature::MMAP, QuantType::SQ8) +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, knowhere::feature::MMAP, + QuantType::SQ8Refine) } // namespace knowhere diff --git a/src/index/index_factory.cc b/src/index/index_factory.cc index 8bce34916..2899b52c6 100644 --- a/src/index/index_factory.cc +++ b/src/index/index_factory.cc @@ -43,7 +43,7 @@ checkGpuAvailable(const std::string& name) { template expected> -IndexFactory::Create(const std::string& name, const int32_t& version, const Object& object) { +IndexFactory::Create(const std::string& name, const int32_t& version, const Object& object, bool runtimeCheck) { static_assert(KnowhereDataTypeCheck::value == true); auto& func_mapping_ = MapInstance(); auto key = GetKey(name); @@ -55,7 +55,7 @@ IndexFactory::Create(const std::string& name, const int32_t& version, const Obje auto fun_map_v = (FunMapValue>*)(func_mapping_[key].get()); #ifdef KNOWHERE_WITH_RAFT - if (!checkGpuAvailable(name)) { + if (runtimeCheck && !checkGpuAvailable(name)) { return expected>::Err(Status::cuda_runtime_error, "gpu not available"); } #endif @@ -65,12 +65,21 @@ IndexFactory::Create(const std::string& name, const int32_t& version, const Obje template const IndexFactory& -IndexFactory::Register(const std::string& name, std::function(const int32_t&, const Object&)> func) { +IndexFactory::Register(const std::string& name, std::function(const int32_t&, const Object&)> func, + const uint64_t features) { static_assert(KnowhereDataTypeCheck::value == true); auto& func_mapping_ = MapInstance(); auto key = GetKey(name); assert(func_mapping_.find(key) == func_mapping_.end()); func_mapping_[key] = std::make_unique>>(func); + auto& feature_mapping_ = FeatureMapInstance(); + // Index feature use the raw name + if (feature_mapping_.find(name) == feature_mapping_.end()) { + feature_mapping_[name] = features; + } else { + // All data types should have the same features; please try to avoid breaking this rule. + feature_mapping_[name] = feature_mapping_[name] & features; + } return *this; } @@ -87,31 +96,54 @@ IndexFactory::MapInstance() { static FuncMap func_map; return func_map; } + +IndexFactory::FeatureMap& +IndexFactory::FeatureMapInstance() { + static FeatureMap featureMap; + return featureMap; +} + IndexFactory::GlobalIndexTable& IndexFactory::StaticIndexTableInstance() { static GlobalIndexTable static_index_table; return static_index_table; } +bool +IndexFactory::FeatureCheck(const std::string& name, uint64_t feature) const { + auto& feature_mapping_ = IndexFactory::FeatureMapInstance(); + assert(feature_mapping_.find(name) == feature_mapping_.end()); + return (feature_mapping_[name] & feature) == feature_mapping_[name]; +} + +const std::map& +IndexFactory::GetIndexFeatures() { + return FeatureMapInstance(); +} + } // namespace knowhere // template knowhere::expected> -knowhere::IndexFactory::Create(const std::string&, const int32_t&, const Object&); +knowhere::IndexFactory::Create(const std::string&, const int32_t&, const Object&, bool); template knowhere::expected> -knowhere::IndexFactory::Create(const std::string&, const int32_t&, const Object&); +knowhere::IndexFactory::Create(const std::string&, const int32_t&, const Object&, bool); template knowhere::expected> -knowhere::IndexFactory::Create(const std::string&, const int32_t&, const Object&); +knowhere::IndexFactory::Create(const std::string&, const int32_t&, const Object&, bool); template knowhere::expected> -knowhere::IndexFactory::Create(const std::string&, const int32_t&, const Object&); +knowhere::IndexFactory::Create(const std::string&, const int32_t&, const Object&, bool); template const knowhere::IndexFactory& knowhere::IndexFactory::Register( - const std::string&, std::function(const int32_t&, const Object&)>); + const std::string&, std::function(const int32_t&, const Object&)>, + const uint64_t); template const knowhere::IndexFactory& knowhere::IndexFactory::Register( - const std::string&, std::function(const int32_t&, const Object&)>); + const std::string&, std::function(const int32_t&, const Object&)>, + const uint64_t); template const knowhere::IndexFactory& knowhere::IndexFactory::Register( - const std::string&, std::function(const int32_t&, const Object&)>); + const std::string&, std::function(const int32_t&, const Object&)>, + const uint64_t); template const knowhere::IndexFactory& knowhere::IndexFactory::Register( - const std::string&, std::function(const int32_t&, const Object&)>); + const std::string&, std::function(const int32_t&, const Object&)>, + const uint64_t); diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 1eb697f62..198350069 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -27,6 +27,7 @@ #include "knowhere/comp/thread_pool.h" #include "knowhere/dataset.h" #include "knowhere/expected.h" +#include "knowhere/feature.h" #include "knowhere/feder/IVFFlat.h" #include "knowhere/index/index_factory.h" #include "knowhere/index/index_node_data_mock_wrapper.h" @@ -1173,39 +1174,24 @@ IvfIndexNode::DeserializeFromFile(const std::string& filena return Status::success; } // bin1 -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFBIN, IvfIndexNode, bin1, faiss::IndexBinaryIVF); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(BIN_IVF_FLAT, IvfIndexNode, bin1, faiss::IndexBinaryIVF); -// fp32 -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, fp32, faiss::IndexIVFFlat); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, fp32, faiss::IndexIVFFlat); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, fp32, faiss::IndexIVFFlatCC); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, fp32, faiss::IndexIVFFlatCC); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(SCANN, IvfIndexNode, fp32, faiss::IndexScaNN); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, fp32, faiss::IndexIVFPQ); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, fp32, faiss::IndexIVFPQ); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_SQ_CC, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizerCC); -// fp16 -KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, fp16, faiss::IndexIVFFlat); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, fp16, faiss::IndexIVFFlat); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, fp16, faiss::IndexIVFFlatCC); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, fp16, faiss::IndexIVFFlatCC); -KNOWHERE_MOCK_REGISTER_GLOBAL(SCANN, IvfIndexNode, fp16, faiss::IndexScaNN); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, fp16, faiss::IndexIVFPQ); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, fp16, faiss::IndexIVFPQ); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, fp16, faiss::IndexIVFScalarQuantizer); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, fp16, faiss::IndexIVFScalarQuantizer); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ_CC, IvfIndexNode, fp16, faiss::IndexIVFScalarQuantizerCC); -// bf16 -KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, bf16, faiss::IndexIVFFlat); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, bf16, faiss::IndexIVFFlat); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, bf16, faiss::IndexIVFFlatCC); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, bf16, faiss::IndexIVFFlatCC); -KNOWHERE_MOCK_REGISTER_GLOBAL(SCANN, IvfIndexNode, bf16, faiss::IndexScaNN); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, bf16, faiss::IndexIVFPQ); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, bf16, faiss::IndexIVFPQ); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, bf16, faiss::IndexIVFScalarQuantizer); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, bf16, faiss::IndexIVFScalarQuantizer); -KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ_CC, IvfIndexNode, bf16, faiss::IndexIVFScalarQuantizerCC); +KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(IVFBIN, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexBinaryIVF) +KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(BIN_IVF_FLAT, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexBinaryIVF) + +// float +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVFFLAT, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFFlat) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_FLAT, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFFlat) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVFFLATCC, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFFlatCC) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_FLAT_CC, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFFlatCC) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(SCANN, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexScaNN) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVFPQ, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFPQ) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_PQ, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFPQ) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVFSQ, IvfIndexNode, knowhere::feature::MMAP, + faiss::IndexIVFScalarQuantizer) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_SQ, IvfIndexNode, knowhere::feature::MMAP, + faiss::IndexIVFScalarQuantizer) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_SQ8, IvfIndexNode, knowhere::feature::MMAP, + faiss::IndexIVFScalarQuantizer) +KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_SQ_CC, IvfIndexNode, knowhere::feature::MMAP, + faiss::IndexIVFScalarQuantizerCC) + } // namespace knowhere diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index bcc0cc44f..0b829e9f0 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -94,6 +94,16 @@ class ScannConfig : public IvfFlatConfig { Status CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { switch (param_type) { + case PARAM_TYPE::TRAIN: { + // TODO: handle odd dim with scann + if (dim.has_value()) { + int vec_dim = dim.value() / 2; + if (vec_dim % 2 != 0) { + *err_msg = "dimension must be able to be divided by 2, dimension:" + std::to_string(vec_dim); + return Status::invalid_args; + } + } + } case PARAM_TYPE::SEARCH: { if (!faiss::support_pq_fast_scan) { LOG_KNOWHERE_ERROR_ << "SCANN index is not supported on the current CPU model, avx2 support is " @@ -128,7 +138,20 @@ class ScannConfig : public IvfFlatConfig { class IvfSqConfig : public IvfConfig {}; -class IvfBinConfig : public IvfConfig {}; +class IvfBinConfig : public IvfConfig { + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + if (param_type == PARAM_TYPE::TRAIN) { + auto legal_metric_list = std::vector{"HAMMING", "JACCARD"}; + std::string metric = metric_type.value(); + if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { + *err_msg = "metric type " + metric + " not found or not supported, supported: [HAMMING JACCARD]"; + return Status::invalid_metric_type; + } + } + return Status::success; + } +}; class IvfSqCcConfig : public IvfFlatCcConfig { public: @@ -151,6 +174,13 @@ class IvfSqCcConfig : public IvfFlatCcConfig { Status CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { if (param_type == PARAM_TYPE::TRAIN) { + auto legal_metric_list = std::vector{"HAMMING", "JACCARD"}; + std::string metric = metric_type.value(); + if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { + *err_msg = "metric type" + metric + " not found or not supported, supported [HAMMING, JACCARD]"; + return Status::invalid_metric_type; + } + auto code_size_v = code_size.value(); auto legal_code_size_list = std::vector{4, 6, 8, 16}; if (std::find(legal_code_size_list.begin(), legal_code_size_list.end(), code_size_v) == diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc index 9efb4c859..96ad8d9bc 100644 --- a/src/index/sparse/sparse_index_node.cc +++ b/src/index/sparse/sparse_index_node.cc @@ -19,6 +19,7 @@ #include "knowhere/config.h" #include "knowhere/dataset.h" #include "knowhere/expected.h" +#include "knowhere/feature.h" #include "knowhere/index/index_factory.h" #include "knowhere/index/index_node.h" #include "knowhere/log.h" @@ -315,7 +316,8 @@ class SparseInvertedIndexNode : public IndexNode { size_t map_size_ = 0; }; // class SparseInvertedIndexNode -KNOWHERE_SIMPLE_REGISTER_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, fp32, /*use_wand=*/false); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(SPARSE_WAND, SparseInvertedIndexNode, fp32, /*use_wand=*/true); - +KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, knowhere::feature::MMAP, + /*use_wand=*/false) +KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_WAND, SparseInvertedIndexNode, knowhere::feature::MMAP, + /*use_wand=*/true) } // namespace knowhere diff --git a/tests/ut/test_config.cc b/tests/ut/test_config.cc index 29ed07dc6..3925d7cee 100644 --- a/tests/ut/test_config.cc +++ b/tests/ut/test_config.cc @@ -15,6 +15,8 @@ #include "index/hnsw/hnsw_config.h" #include "index/ivf/ivf_config.h" #include "knowhere/config.h" +#include "knowhere/index/index_factory.h" +#include "knowhere/version.h" #ifdef KNOWHERE_WITH_DISKANN #include "index/diskann/diskann_config.h" #endif @@ -22,6 +24,39 @@ #include "index/gpu_raft/gpu_raft_cagra_config.h" #endif +void +checkBuildConfig(knowhere::IndexType indexType, knowhere::Json& json) { + std::string msg; + if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::BINARY)) { + CHECK(knowhere::CheckConfig(indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), + json, knowhere::PARAM_TYPE::TRAIN, + msg) == knowhere::Status::success); + CHECK(msg.empty()); + } + if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::FLOAT32)) { + CHECK(knowhere::CheckConfig(indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), json, + knowhere::PARAM_TYPE::TRAIN, msg) == knowhere::Status::success); + CHECK(msg.empty()); + } + if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::BF16)) { + CHECK(knowhere::CheckConfig(indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), + json, knowhere::PARAM_TYPE::TRAIN, + msg) == knowhere::Status::success); + CHECK(msg.empty()); + } + if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::FP16)) { + CHECK(knowhere::CheckConfig(indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), + json, knowhere::PARAM_TYPE::TRAIN, + msg) == knowhere::Status::success); + CHECK(msg.empty()); + } + if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::SPARSE_FLOAT32)) { + CHECK(knowhere::CheckConfig(indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), json, + knowhere::PARAM_TYPE::TRAIN, msg) == knowhere::Status::success); + CHECK(msg.empty()); + } +} + TEST_CASE("Test config json parse", "[config]") { knowhere::Status s; std::string err_msg; @@ -96,9 +131,15 @@ TEST_CASE("Test config json parse", "[config]") { })"); knowhere::HnswConfig hnsw_config; s = knowhere::Config::FormatAndCheck(hnsw_config, large_build_json); + + checkBuildConfig(knowhere::IndexEnum::INDEX_HNSW, large_build_json); + CHECK(s == knowhere::Status::success); #ifdef KNOWHERE_WITH_DISKANN knowhere::DiskANNConfig diskann_config; + + checkBuildConfig(knowhere::IndexEnum::INDEX_DISKANN, large_build_json); + s = knowhere::Config::FormatAndCheck(diskann_config, large_build_json); CHECK(s == knowhere::Status::success); #endif @@ -137,6 +178,7 @@ TEST_CASE("Test config json parse", "[config]") { "k": 100 })"); + checkBuildConfig(knowhere::IndexEnum::INDEX_FAISS_IDMAP, json); knowhere::FlatConfig train_cfg; s = knowhere::Config::Load(train_cfg, json, knowhere::TRAIN); CHECK(s == knowhere::Status::success); @@ -160,6 +202,7 @@ TEST_CASE("Test config json parse", "[config]") { "trace_visit": true })"); knowhere::IvfFlatConfig train_cfg; + checkBuildConfig(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, json); s = knowhere::Config::Load(train_cfg, json, knowhere::TRAIN); CHECK(s == knowhere::Status::success); CHECK(train_cfg.metric_type.value() == "L2"); @@ -202,6 +245,7 @@ TEST_CASE("Test config json parse", "[config]") { knowhere::HnswConfig wrong_cfg; auto invalid_value_json = json; invalid_value_json["efConstruction"] = 100.10; + checkBuildConfig(knowhere::IndexEnum::INDEX_HNSW, json); s = knowhere::Config::Load(wrong_cfg, invalid_value_json, knowhere::TRAIN); CHECK(s == knowhere::Status::type_conflict_in_json); @@ -281,6 +325,7 @@ TEST_CASE("Test config json parse", "[config]") { })"); { knowhere::DiskANNConfig train_cfg; + checkBuildConfig(knowhere::IndexEnum::INDEX_DISKANN, json); s = knowhere::Config::Load(train_cfg, json, knowhere::TRAIN); CHECK(s == knowhere::Status::success); CHECK_EQ(128, train_cfg.search_list_size.value());