Skip to content

Commit d0997e0

Browse files
committed
fix: reject invalid HNSW binary metrics in static config check
Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
1 parent 7041c3e commit d0997e0

2 files changed

Lines changed: 42 additions & 16 deletions

File tree

src/index/hnsw/faiss_hnsw.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2174,6 +2174,11 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplate : public BaseFaissRegularIndexHN
21742174
StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) {
21752175
return true;
21762176
}
2177+
2178+
static Status
2179+
StaticConfigCheck(const Config& cfg, PARAM_TYPE paramType, std::string& msg) {
2180+
return HnswIndexNode<DataType, hnswlib::QuantType::None>::StaticConfigCheck(cfg, paramType, msg);
2181+
}
21772182
};
21782183

21792184
// this is a regular node that can be initialized as some existing index type,
@@ -2494,6 +2499,11 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public HNSWI
24942499
return true;
24952500
}
24962501

2502+
static Status
2503+
StaticConfigCheck(const Config& cfg, PARAM_TYPE paramType, std::string& msg) {
2504+
return BaseFaissRegularIndexHNSWFlatNodeTemplate<DataType>::StaticConfigCheck(cfg, paramType, msg);
2505+
}
2506+
24972507
static std::unique_ptr<BaseConfig>
24982508
StaticCreateConfig() {
24992509
return std::make_unique<FaissHnswFlatConfig>();

tests/ut/test_config.cc

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,42 +28,44 @@
2828
#endif
2929

3030
void
31-
checkBuildConfig(knowhere::IndexType indexType, knowhere::Json& json) {
32-
std::string msg;
31+
checkBuildConfig(knowhere::IndexType indexType, const knowhere::Json& json) {
32+
const auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
3333
if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::BINARY)) {
34-
CHECK(knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(
35-
indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), json, msg) ==
34+
auto binary_json = json;
35+
binary_json[knowhere::meta::METRIC_TYPE] = knowhere::metric::HAMMING;
36+
std::string msg;
37+
CHECK(knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(indexType, version, binary_json, msg) ==
3638
knowhere::Status::success);
3739
CHECK(msg.empty());
3840
}
3941
if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::FLOAT32)) {
40-
CHECK(knowhere::IndexStaticFaced<float>::ConfigCheck(indexType,
41-
knowhere::Version::GetCurrentVersion().VersionNumber(),
42-
json, msg) == knowhere::Status::success);
42+
std::string msg;
43+
CHECK(knowhere::IndexStaticFaced<float>::ConfigCheck(indexType, version, json, msg) ==
44+
knowhere::Status::success);
4345
CHECK(msg.empty());
4446
}
4547
if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::BF16)) {
46-
CHECK(knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(
47-
indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), json, msg) ==
48+
std::string msg;
49+
CHECK(knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(indexType, version, json, msg) ==
4850
knowhere::Status::success);
4951
CHECK(msg.empty());
5052
}
5153
if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::FP16)) {
52-
CHECK(knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(
53-
indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), json, msg) ==
54+
std::string msg;
55+
CHECK(knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(indexType, version, json, msg) ==
5456
knowhere::Status::success);
5557
CHECK(msg.empty());
5658
}
5759
if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::SPARSE_U32_F32)) {
58-
CHECK(knowhere::IndexStaticFaced<float>::ConfigCheck(indexType,
59-
knowhere::Version::GetCurrentVersion().VersionNumber(),
60-
json, msg) == knowhere::Status::success);
60+
std::string msg;
61+
CHECK(knowhere::IndexStaticFaced<float>::ConfigCheck(indexType, version, json, msg) ==
62+
knowhere::Status::success);
6163
CHECK(msg.empty());
6264
}
6365
#ifndef KNOWHERE_WITH_CARDINAL
6466
if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::INT8)) {
65-
CHECK(knowhere::IndexStaticFaced<knowhere::int8>::ConfigCheck(
66-
indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), json, msg) ==
67+
std::string msg;
68+
CHECK(knowhere::IndexStaticFaced<knowhere::int8>::ConfigCheck(indexType, version, json, msg) ==
6769
knowhere::Status::success);
6870
CHECK(msg.empty());
6971
}
@@ -514,6 +516,20 @@ TEST_CASE("Test config json parse", "[config]") {
514516
CHECK(msg.empty());
515517
}
516518

519+
// ---- HNSW (bin1): binary metric whitelist is enforced ----
520+
{
521+
knowhere::Json bad_json = knowhere::Json::parse(R"({
522+
"metric_type": "L2",
523+
"M": 16,
524+
"efConstruction": 96
525+
})");
526+
std::string msg;
527+
CHECK(knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(knowhere::IndexEnum::INDEX_HNSW, version,
528+
bad_json, msg) ==
529+
knowhere::Status::invalid_metric_type);
530+
CHECK_FALSE(msg.empty());
531+
}
532+
517533
// ---- IVF_FLAT (int8): int8 is mocked to fp32 at runtime, so float metrics must pass ----
518534
// (Without the companion fix to IvfIndexNode::StaticConfigCheck, dispatching to it for
519535
// int8 would wrongly fall into the binary branch and reject L2.)

0 commit comments

Comments
 (0)