diff --git a/include/knowhere/config.h b/include/knowhere/config.h index faf36d6cb..67445ef79 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -38,6 +38,10 @@ typedef nlohmann::json Json; #define CFG_INT std::optional #endif +#ifndef CFG_INT64 +#define CFG_INT64 std::optional +#endif + #ifndef CFG_STRING #define CFG_STRING std::optional #endif @@ -140,6 +144,31 @@ struct Entry { bool allow_empty_without_default = false; }; +template <> +struct Entry { + explicit Entry(CFG_INT64* v) { + val = v; + default_val = std::nullopt; + type = 0x0; + range = std::nullopt; + desc = std::nullopt; + } + Entry() { + val = nullptr; + default_val = std::nullopt; + type = 0x0; + range = std::nullopt; + desc = std::nullopt; + } + + CFG_INT64* val; + std::optional default_val; + uint32_t type; + std::optional> range; + std::optional desc; + bool allow_empty_without_default = false; +}; + template <> struct Entry { explicit Entry(CFG_BOOL* v) { @@ -317,12 +346,12 @@ class Config { } if (!json[it.first].is_number_integer()) { std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + - ") should be integer"; + ") should be integer(64bit)"; show_err_msg(msg); return Status::type_conflict_in_json; } if (ptr->range.has_value()) { - if (json[it.first].get() > std::numeric_limits::max()) { + if (json[it.first].get() > std::numeric_limits::max()) { std::string msg = "Arithmetic overflow: param '" + it.first + "' (" + to_string(json[it.first]) + ") should not bigger than " + std::to_string(std::numeric_limits::max()); @@ -346,6 +375,54 @@ class Config { } } + if (const Entry* ptr = std::get_if>(&var)) { + if (!(type & ptr->type)) { + continue; + } + if (json.find(it.first) == json.end()) { + if (!ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { + continue; + } + std::string msg = "param '" + it.first + "' not exist in json"; + show_err_msg(msg); + return Status::invalid_param_in_json; + } else { + *ptr->val = ptr->default_val; + continue; + } + } + if (!json[it.first].is_number_integer()) { + std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + + ") should be unsigned integer"; + show_err_msg(msg); + return Status::type_conflict_in_json; + } + if (ptr->range.has_value()) { + if (json[it.first].get() > std::numeric_limits::max()) { + std::string msg = "Arithmetic overflow: param '" + it.first + "' (" + + to_string(json[it.first]) + ") should not bigger than " + + std::to_string(std::numeric_limits::max()); + show_err_msg(msg); + return Status::arithmetic_overflow; + } + CFG_INT64::value_type v = json[it.first]; + auto range_val = ptr->range.value(); + if (range_val.first <= v && v <= range_val.second) { + *ptr->val = v; + } else { + std::string msg = "Out of range in json: param '" + it.first + "' (" + + to_string(json[it.first]) + ") should be in range [" + + std::to_string(range_val.first) + ", " + std::to_string(range_val.second) + + "]"; + show_err_msg(msg); + return Status::out_of_range_in_json; + } + } else { + *ptr->val = json[it.first]; + } + } + if (const Entry* ptr = std::get_if>(&var)) { if (!(type & ptr->type)) { continue; @@ -478,8 +555,8 @@ class Config { virtual ~Config() { } - using VarEntry = std::variant, Entry, Entry, Entry, - Entry>; + using VarEntry = std::variant, Entry, Entry, Entry, + Entry, Entry>; std::unordered_map __DICT__; protected: @@ -501,7 +578,7 @@ const float defaultRangeFilter = 1.0f / 0.0; class BaseConfig : public Config { public: - CFG_INT dim; // just used for config verify + CFG_INT64 dim; // just used for config verify CFG_STRING metric_type; CFG_INT k; CFG_INT num_build_thread; diff --git a/src/common/config.cc b/src/common/config.cc index 81cec6461..63143de34 100644 --- a/src/common/config.cc +++ b/src/common/config.cc @@ -97,12 +97,17 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg } if (v < std::numeric_limits::min() || v > std::numeric_limits::max()) { - *err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'"; + if (err_msg) { + *err_msg = + "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'"; + } return knowhere::Status::invalid_value_in_json; } json[key_str] = static_cast(v); } catch (const std::out_of_range&) { - *err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'"; + if (err_msg) { + *err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'"; + } return knowhere::Status::invalid_value_in_json; } catch (const std::invalid_argument&) { KNOWHERE_THROW_MSG("invalid integer value, key: '" + key_str + "', value: '" + value_str + "'"); diff --git a/src/index/diskann/diskann_config.h b/src/index/diskann/diskann_config.h index db94bd922..a8622b1d1 100644 --- a/src/index/diskann/diskann_config.h +++ b/src/index/diskann/diskann_config.h @@ -182,8 +182,10 @@ class DiskANNConfig : public BaseConfig { if (!search_list_size.has_value()) { search_list_size = std::max(k.value(), kSearchListSizeMinValue); } else if (k.value() > search_list_size.value()) { - *err_msg = "search_list_size(" + std::to_string(search_list_size.value()) + - ") should be larger than k(" + std::to_string(k.value()) + ")"; + if (err_msg) { + *err_msg = "search_list_size(" + std::to_string(search_list_size.value()) + + ") should be larger than k(" + std::to_string(k.value()) + ")"; + } LOG_KNOWHERE_ERROR_ << *err_msg; return Status::out_of_range_in_json; } diff --git a/src/index/gpu_raft/gpu_raft_brute_force_config.h b/src/index/gpu_raft/gpu_raft_brute_force_config.h index 2383aa8ef..e720ba00f 100644 --- a/src/index/gpu_raft/gpu_raft_brute_force_config.h +++ b/src/index/gpu_raft/gpu_raft_brute_force_config.h @@ -30,7 +30,9 @@ struct GpuRaftBruteForceConfig : public BaseConfig { constexpr std::array legal_metric_list{"L2", "IP"}; std::string metric = metric_type.value(); if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { - *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + if (err_msg) { + *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + } return Status::invalid_metric_type; } } diff --git a/src/index/gpu_raft/gpu_raft_cagra_config.h b/src/index/gpu_raft/gpu_raft_cagra_config.h index ad7de43ae..6c03b0255 100644 --- a/src/index/gpu_raft/gpu_raft_cagra_config.h +++ b/src/index/gpu_raft/gpu_raft_cagra_config.h @@ -129,8 +129,10 @@ struct GpuRaftCagraConfig : public BaseConfig { if (search_width.has_value()) { if (std::max(itopk_size.value(), kAlignFactor * search_width.value()) < k.value()) { - *err_msg = "max((itopk_size + 31)// 32, search_width) * 32< topk"; - LOG_KNOWHERE_ERROR_ << *err_msg; + if (err_msg) { + *err_msg = "max((itopk_size + 31)// 32, search_width) * 32< topk"; + LOG_KNOWHERE_ERROR_ << *err_msg; + } return Status::out_of_range_in_json; } } else { diff --git a/src/index/gpu_raft/gpu_raft_ivf_flat_config.h b/src/index/gpu_raft/gpu_raft_ivf_flat_config.h index ffcaf2d17..32e8aea2b 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_flat_config.h +++ b/src/index/gpu_raft/gpu_raft_ivf_flat_config.h @@ -64,7 +64,9 @@ struct GpuRaftIvfFlatConfig : public IvfFlatConfig { constexpr std::array legal_metric_list{"L2", "IP"}; std::string metric = metric_type.value(); if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { - *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + if (err_msg) { + *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + } return Status::invalid_metric_type; } } diff --git a/src/index/gpu_raft/gpu_raft_ivf_pq_config.h b/src/index/gpu_raft/gpu_raft_ivf_pq_config.h index 1b4e4ce9e..744d39988 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_pq_config.h +++ b/src/index/gpu_raft/gpu_raft_ivf_pq_config.h @@ -99,7 +99,9 @@ struct GpuRaftIvfPqConfig : public IvfPqConfig { constexpr std::array legal_metric_list{"L2", "IP"}; std::string metric = metric_type.value(); if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { - *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + if (err_msg) { + *err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]"; + } return Status::invalid_metric_type; } } diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index 2bf0f6a80..d5ded26ad 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -89,9 +89,11 @@ class FaissHnswConfig : public BaseConfig { if (!ef.has_value()) { ef = std::max(k.value(), kEfMinValue); } else if (k.value() > ef.value()) { - *err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" + - std::to_string(k.value()) + ")"; - LOG_KNOWHERE_ERROR_ << *err_msg; + if (err_msg) { + *err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" + + std::to_string(k.value()) + ")"; + LOG_KNOWHERE_ERROR_ << *err_msg; + } return Status::out_of_range_in_json; } break; @@ -140,8 +142,10 @@ class FaissHnswFlatConfig : public FaissHnswConfig { if (param_type == PARAM_TYPE::TRAIN) { // prohibit refine if (refine.value_or(false) || refine_type.has_value() || refine_k.has_value()) { - *err_msg = "refine is not supported for this index"; - LOG_KNOWHERE_ERROR_ << *err_msg; + if (err_msg) { + *err_msg = "refine is not supported for this index"; + LOG_KNOWHERE_ERROR_ << *err_msg; + } return Status::invalid_value_in_json; } } @@ -174,16 +178,20 @@ class FaissHnswSqConfig : public FaissHnswConfig { if (param_type == PARAM_TYPE::TRAIN) { auto sq_type_v = sq_type.value(); if (!WhetherAcceptableQuantType(sq_type_v)) { - *err_msg = "invalid scalar quantizer type"; - LOG_KNOWHERE_ERROR_ << *err_msg; + if (err_msg) { + *err_msg = "invalid scalar quantizer type"; + LOG_KNOWHERE_ERROR_ << *err_msg; + } return Status::invalid_value_in_json; } // check refine if (refine_type.has_value()) { if (!WhetherAcceptableRefineType(refine_type.value())) { - *err_msg = "invalid refine type type"; - LOG_KNOWHERE_ERROR_ << *err_msg; + if (err_msg) { + *err_msg = "invalid refine type type"; + LOG_KNOWHERE_ERROR_ << *err_msg; + } return Status::invalid_value_in_json; } } diff --git a/src/index/hnsw/hnsw_config.h b/src/index/hnsw/hnsw_config.h index d8f82c782..b60626c50 100644 --- a/src/index/hnsw/hnsw_config.h +++ b/src/index/hnsw/hnsw_config.h @@ -59,9 +59,11 @@ class HnswConfig : public BaseConfig { if (!ef.has_value()) { ef = std::max(k.value(), kEfMinValue); } else if (k.value() > ef.value()) { - *err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" + - std::to_string(k.value()) + ")"; - LOG_KNOWHERE_ERROR_ << *err_msg; + if (err_msg) { + *err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" + + std::to_string(k.value()) + ")"; + LOG_KNOWHERE_ERROR_ << *err_msg; + } return Status::out_of_range_in_json; } break; diff --git a/src/index/index.cc b/src/index/index.cc index a3b44e702..2e6f7e421 100644 --- a/src/index/index.cc +++ b/src/index/index.cc @@ -92,7 +92,8 @@ template inline Status Index::Train(const DataSetPtr dataset, const Json& json) { auto cfg = this->node->CreateConfig(); - RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Train")); + std::string msg; + RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Train", &msg)); return this->node->Train(dataset, std::move(cfg)); } @@ -100,7 +101,8 @@ template inline Status Index::Add(const DataSetPtr dataset, const Json& json) { auto cfg = this->node->CreateConfig(); - RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Add")); + std::string msg; + RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Add", &msg)); return this->node->Add(dataset, std::move(cfg)); } diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index 5b117e77b..a65b3a27e 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -78,9 +78,11 @@ class IvfPqConfig : public IvfConfig { int vec_dim = dim.value(); int param_m = m.value(); if (vec_dim % param_m != 0) { - *err_msg = - "dimension must be able to be divided by `m`, dimension: " + std::to_string(vec_dim) + - ", m: " + std::to_string(param_m); + if (err_msg) { + *err_msg = + "dimension must be able to be divided by `m`, dimension: " + std::to_string(vec_dim) + + ", m: " + std::to_string(param_m); + } return Status::invalid_args; } } @@ -115,7 +117,10 @@ class ScannConfig : public IvfFlatConfig { if (dim.has_value()) { int vec_dim = dim.value(); if (vec_dim % 2 != 0) { - *err_msg = "dimension must be able to be divided by 2, dimension:" + std::to_string(vec_dim); + if (err_msg) { + *err_msg = + "dimension must be able to be divided by 2, dimension:" + std::to_string(vec_dim); + } return Status::invalid_args; } } @@ -161,7 +166,9 @@ class IvfBinConfig : public IvfConfig { constexpr std::array legal_metric_list{"HAMMING", "JACCARD"}; std::string metric = metric_type.value(); if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) { - *err_msg = "metric type " + metric + " not found or not supported, supported: [HAMMING JACCARD]"; + if (err_msg) { + *err_msg = "metric type " + metric + " not found or not supported, supported: [HAMMING JACCARD]"; + } return Status::invalid_metric_type; } } @@ -195,9 +202,11 @@ class IvfSqCcConfig : public IvfFlatCcConfig { auto legal_code_size_list = std::vector{4, 6, 8, 16}; if (std::find(legal_code_size_list.begin(), legal_code_size_list.end(), code_size_v) == legal_code_size_list.end()) { - *err_msg = - "compress a vector into (code_size * dim)/8 bytes, code size value should be in 4, 6, 8 and 16"; - LOG_KNOWHERE_ERROR_ << *err_msg; + if (err_msg) { + *err_msg = + "compress a vector into (code_size * dim)/8 bytes, code size value should be in 4, 6, 8 and 16"; + LOG_KNOWHERE_ERROR_ << *err_msg; + } return Status::invalid_value_in_json; } } diff --git a/tests/ut/test_config.cc b/tests/ut/test_config.cc index 318b8b66f..8ded69d8f 100644 --- a/tests/ut/test_config.cc +++ b/tests/ut/test_config.cc @@ -93,6 +93,20 @@ TEST_CASE("Test config json parse", "[config]") { CHECK(s == knowhere::Status::success); } + SECTION("check int64 json values") { + auto unsigned_int_json_str = GENERATE(as{}, + R"({ + "dim": 10000000000 + })"); + knowhere::BaseConfig test_config; + knowhere::Json test_json = knowhere::Json::parse(unsigned_int_json_str); + s = knowhere::Config::FormatAndCheck(test_config, test_json); + CHECK(s == knowhere::Status::success); + s = knowhere::Config::Load(test_config, test_json, knowhere::TRAIN); + CHECK(s == knowhere::Status::success); + CHECK(test_config.dim.value() == 10000000000L); + } + SECTION("check invalid json values") { auto invalid_json_str = GENERATE(as{}, R"({