diff --git a/.gitignore b/.gitignore index 0aae0406f..b89b4f834 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,7 @@ graph_info.json # Claude Code plans (local only) docs/plans/ +docs/superpowers/ + +# Test artifacts (Catch2 tests write serialized indexes into CWD) +*.index diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index c5b9f73f1..fbf4bfb67 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -25,6 +25,7 @@ namespace IndexEnum { constexpr const char* INVALID = ""; constexpr const char* INDEX_FAISS_BIN_IDMAP = "BIN_FLAT"; +constexpr const char* INDEX_FAISS = "FAISS"; constexpr const char* INDEX_FAISS_BIN_IVFFLAT = "BIN_IVF_FLAT"; constexpr const char* INDEX_FAISS_IDMAP = "FLAT"; diff --git a/include/knowhere/config.h b/include/knowhere/config.h index 619630152..474c3bf4d 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -655,6 +655,15 @@ class BaseConfig : public Config { CFG_INT lemur_seed; // random seed for LEMUR CFG_INT lemur_num_layers; // number of layers in feature_extractor CFG_BOOL emb_list_rerank; // whether to perform MaxSim reranking after ANN search + + /// Optional hook: runs after FormatAndCheck and before Config::Load consumes typed + /// fields. Used by FaissConfig to capture the raw JSON verbatim for pass-through to + /// faiss's ParameterSpace. Default is a no-op; do NOT override unless you need raw + /// JSON (most configs should rely on KNOWHERE_CONFIG_DECLARE_FIELD). 
+ virtual void + CaptureRawJson(const Json& /*json*/) { + } + KNOWHERE_DECLARE_CONFIG(BaseConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(dim).allow_empty_without_default().description("vector dim").for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(metric_type) diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 7e1055689..d807305e9 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -28,6 +28,9 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_BFLOAT16}, // {IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_INT8}, + {IndexEnum::INDEX_FAISS, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS, VecType::VECTOR_BINARY}, + {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT}, {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT16}, {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_BFLOAT16}, diff --git a/src/index/faiss/faiss.cc b/src/index/faiss/faiss.cc new file mode 100644 index 000000000..8564cc509 --- /dev/null +++ b/src/index/faiss/faiss.cc @@ -0,0 +1,431 @@ +// Copyright (C) 2019-2026 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +// except in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the +// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +// either express or implied. See the License for the specific language governing permissions +// and limitations under the License. 
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "common/metric.h" +#include "folly/futures/Future.h" +#include "index/faiss/faiss_config.h" +#include "index/faiss/faiss_dispatch.h" +#include "knowhere/bitsetview_idselector.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/task.h" +#include "knowhere/context.h" +#include "knowhere/dataset.h" +#include "knowhere/feature.h" +#include "knowhere/index/index_factory.h" +#include "knowhere/index/index_node.h" +#include "knowhere/log.h" +#include "knowhere/operands.h" +#include "knowhere/range_util.h" +#include "knowhere/thread_pool.h" +#include "knowhere/utils.h" + +namespace knowhere { + +namespace { +// Backing faiss base type per DataType. +template +using FaissBase = std::conditional_t, faiss::Index, faiss::IndexBinary>; +} // namespace + +template +class FaissIndexNode : public IndexNode { + public: + static_assert(std::is_same_v || std::is_same_v, + "FaissIndexNode supports only fp32 and bin1"); + + FaissIndexNode(const int32_t version, const Object& /*object*/) : IndexNode(version) { + search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); + } + + Status + Train(const DataSetPtr dataset, std::shared_ptr cfg, bool /*use_knowhere_build_pool*/) override { + const auto* fc = static_cast(cfg.get()); + const auto metric = Str2FaissMetricType(fc->metric_type.value()); + if (!metric.has_value()) { + return Status::invalid_metric_type; + } + is_cosine_ = IsMetricType(fc->metric_type.value(), knowhere::metric::COSINE); + + try { + if constexpr (std::is_same_v) { + index_.reset(::faiss::index_factory(static_cast(dataset->GetDim()), + fc->faiss_index_name.value().c_str(), metric.value())); + } else { + index_.reset(::faiss::index_binary_factory(static_cast(dataset->GetDim()), + fc->faiss_index_name.value().c_str())); + } + } catch (const ::faiss::FaissException& e) { + LOG_KNOWHERE_ERROR_ << "faiss::index_factory failed: " << 
e.what(); + return Status::invalid_args; + } + + std::string err; + auto st = faiss_vanilla::apply_build_params(index_.get(), fc->raw_params, &err); + if (st != Status::success) { + LOG_KNOWHERE_ERROR_ << err; + return st; + } + + try { + const auto* raw = dataset->GetTensor(); + const auto n = dataset->GetRows(); + if constexpr (std::is_same_v) { + auto data = static_cast(raw); + std::unique_ptr copy; + if (is_cosine_) { + copy = CopyAndNormalizeVecs(data, n, dataset->GetDim()); + data = copy.get(); + } + index_->train(n, data); + } else { + index_->train(n, static_cast(raw)); + } + } catch (const ::faiss::FaissException& e) { + LOG_KNOWHERE_ERROR_ << "faiss train failed: " << e.what(); + return Status::faiss_inner_error; + } + return Status::success; + } + + Status + Add(const DataSetPtr dataset, std::shared_ptr /*cfg*/, bool /*use_knowhere_build_pool*/) override { + if (!index_) { + return Status::empty_index; + } + try { + const auto* raw = dataset->GetTensor(); + const auto n = dataset->GetRows(); + if constexpr (std::is_same_v) { + auto data = static_cast(raw); + std::unique_ptr copy; + if (is_cosine_) { + copy = CopyAndNormalizeVecs(data, n, dataset->GetDim()); + data = copy.get(); + } + index_->add(n, data); + } else { + index_->add(n, static_cast(raw)); + } + } catch (const ::faiss::FaissException& e) { + LOG_KNOWHERE_ERROR_ << "faiss add failed: " << e.what(); + return Status::faiss_inner_error; + } + return Status::success; + } + + expected + Search(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + milvus::OpContext* op_context) const override { + if (!index_) { + return expected::Err(Status::empty_index, "index not loaded"); + } + const auto* fc = static_cast(cfg.get()); + const auto k = static_cast(fc->k.value()); + const auto nq = dataset->GetRows(); + const auto dim = dataset->GetDim(); + + BitsetViewIDSelector bw_sel(bitset); + ::faiss::IDSelector* sel = bitset.empty() ? 
nullptr : &bw_sel; + + std::unique_ptr<::faiss::SearchParameters> search_params; + std::string err_msg; + + auto ids = std::make_unique(nq * k); + auto distances = std::make_unique(nq * k); + + try { + if constexpr (std::is_same_v) { + Status st = faiss_vanilla::build_search_params(static_cast(index_.get()), + fc->raw_params, sel, &search_params, &err_msg); + if (st != Status::success) { + LOG_KNOWHERE_ERROR_ << err_msg; + return expected::Err(st, err_msg); + } + + const auto* raw = static_cast(dataset->GetTensor()); + std::unique_ptr norm_copy; + if (is_cosine_) { + norm_copy = CopyAndNormalizeVecs(raw, nq, dim); + raw = norm_copy.get(); + } + + std::vector> futs; + futs.reserve(nq); + for (int64_t i = 0; i < nq; ++i) { + futs.emplace_back(search_pool_->push([&, idx = i]() { + knowhere::checkCancellation(op_context); + ThreadPool::ScopedSearchOmpSetter setter(1); + index_->search(1, raw + idx * dim, k, distances.get() + idx * k, ids.get() + idx * k, + search_params.get()); + })); + } + WaitAllSuccess(futs); + } else { + Status st = faiss_vanilla::build_search_params(static_cast(index_.get()), + fc->raw_params, sel, &search_params, &err_msg); + if (st != Status::success) { + LOG_KNOWHERE_ERROR_ << err_msg; + return expected::Err(st, err_msg); + } + + const auto* raw = static_cast(dataset->GetTensor()); + // dim is in bits for binary indexes; bytes = dim / 8 + const auto bytes_per_vec = dim / 8; + // faiss binary search returns int32 distances; cast to float afterwards + auto int_distances = std::make_unique(nq * k); + + std::vector> futs; + futs.reserve(nq); + for (int64_t i = 0; i < nq; ++i) { + futs.emplace_back(search_pool_->push([&, idx = i]() { + knowhere::checkCancellation(op_context); + ThreadPool::ScopedSearchOmpSetter setter(1); + index_->search(1, raw + idx * bytes_per_vec, k, int_distances.get() + idx * k, + ids.get() + idx * k, search_params.get()); + })); + } + WaitAllSuccess(futs); + for (int64_t i = 0; i < nq * k; ++i) { + distances[i] = 
static_cast(int_distances[i]); + } + } + } catch (const ::faiss::FaissException& e) { + LOG_KNOWHERE_ERROR_ << "faiss search failed: " << e.what(); + return expected::Err(Status::faiss_inner_error, e.what()); + } + return GenResultDataSet(nq, k, std::move(ids), std::move(distances)); + } + + expected + RangeSearch(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + milvus::OpContext* op_context) const override { + if (!index_) { + return expected::Err(Status::empty_index, "index not loaded"); + } + + if constexpr (std::is_same_v) { + return expected::Err(Status::not_implemented, + "RangeSearch unsupported for binary faiss indexes"); + } else { + const auto* fc = static_cast(cfg.get()); + const float radius = fc->radius.value(); + const float range_filter = fc->range_filter.value(); + const auto nq = dataset->GetRows(); + const auto dim = dataset->GetDim(); + + BitsetViewIDSelector bw_sel(bitset); + ::faiss::IDSelector* sel = bitset.empty() ? nullptr : &bw_sel; + + std::unique_ptr<::faiss::SearchParameters> search_params; + std::string err_msg; + Status st = faiss_vanilla::build_search_params(static_cast(index_.get()), + fc->raw_params, sel, &search_params, &err_msg); + if (st != Status::success) { + LOG_KNOWHERE_ERROR_ << err_msg; + return expected::Err(st, err_msg); + } + + const auto* raw = static_cast(dataset->GetTensor()); + std::unique_ptr norm_copy; + if (is_cosine_) { + norm_copy = CopyAndNormalizeVecs(raw, nq, dim); + raw = norm_copy.get(); + } + + std::vector> result_distances(nq); + std::vector> result_labels(nq); + + try { + std::vector> futs; + futs.reserve(nq); + for (int64_t i = 0; i < nq; ++i) { + futs.emplace_back(search_pool_->push([&, idx = i]() { + knowhere::checkCancellation(op_context); + ThreadPool::ScopedSearchOmpSetter setter(1); + ::faiss::RangeSearchResult r(1); + index_->range_search(1, raw + idx * dim, radius, &r, search_params.get()); + const size_t cnt = r.lims[1]; + result_distances[idx].assign(r.distances, 
r.distances + cnt); + result_labels[idx].assign(r.labels, r.labels + cnt); + })); + } + WaitAllSuccess(futs); + } catch (const ::faiss::FaissException& e) { + const std::string msg = std::string("faiss range_search failed: ") + e.what() + + ". Please check if the corresponding faiss index has " + "implemented interface"; + LOG_KNOWHERE_ERROR_ << msg; + return expected::Err(Status::faiss_inner_error, msg); + } + + const bool is_ip = is_cosine_ || IsMetricType(fc->metric_type.value(), knowhere::metric::IP); + auto rr = GetRangeSearchResult(result_distances, result_labels, is_ip, nq, radius, range_filter); + return GenResultDataSet(nq, std::move(rr)); + } + } + + // Vanilla faiss adapter does not expose raw vectors: faiss::reconstruct() + // is lossy on quantized indexes (PQ, SQ, ...) and unsupported on some others. + // Callers needing raw data should use a dedicated index type. + expected + GetVectorByIds(const DataSetPtr /*dataset*/, milvus::OpContext* /*op_context*/) const override { + return expected::Err(Status::not_implemented, + "GetVectorByIds not supported by vanilla faiss adapter"); + } + + bool + HasRawData(const std::string& /*metric_type*/) const override { + return false; + } + + expected + GetIndexMeta(std::unique_ptr /*cfg*/) const override { + return expected::Err(Status::not_implemented, "GetIndexMeta not supported"); + } + + Status + Serialize(BinarySet& binset) const override { + if (!index_) { + return Status::empty_index; + } + try { + ::faiss::VectorIOWriter writer; + if constexpr (std::is_same_v) { + ::faiss::write_index(index_.get(), &writer); + } else { + ::faiss::write_index_binary(index_.get(), &writer); + } + auto sz = writer.data.size(); + std::shared_ptr buf(new uint8_t[sz]); + std::memcpy(buf.get(), writer.data.data(), sz); + binset.Append(Type(), buf, static_cast(sz)); + return Status::success; + } catch (const ::faiss::FaissException& e) { + LOG_KNOWHERE_ERROR_ << "Serialize failed: " << e.what(); + return Status::faiss_inner_error; 
+ } + } + + Status + Deserialize(const BinarySet& binset, std::shared_ptr /*config*/) override { + auto bin = binset.GetByName(Type()); + if (bin == nullptr) { + return Status::invalid_binary_set; + } + try { + ::faiss::VectorIOReader reader; + reader.data.assign(bin->data.get(), bin->data.get() + bin->size); + if constexpr (std::is_same_v) { + index_.reset(::faiss::read_index(&reader)); + } else { + index_.reset(::faiss::read_index_binary(&reader)); + } + return Status::success; + } catch (const ::faiss::FaissException& e) { + LOG_KNOWHERE_ERROR_ << "Deserialize failed: " << e.what(); + return Status::faiss_inner_error; + } + } + + Status + DeserializeFromFile(const std::string& filename, std::shared_ptr config) override { + const auto* fc = static_cast(config.get()); + const bool use_mmap = fc->enable_mmap.value_or(false); + try { + if constexpr (std::is_same_v) { + if (use_mmap) { + auto owner = std::make_shared(filename.data()); + faiss::MappedFileIOReader reader(owner); + index_.reset(faiss::read_index(&reader)); + } else { + faiss::FileIOReader reader(filename.data()); + index_.reset(faiss::read_index(&reader)); + } + } else { + if (use_mmap) { + auto owner = std::make_shared(filename.data()); + faiss::MappedFileIOReader reader(owner); + index_.reset(faiss::read_index_binary(&reader)); + } else { + faiss::FileIOReader reader(filename.data()); + index_.reset(faiss::read_index_binary(&reader)); + } + } + return Status::success; + } catch (const ::faiss::FaissException& e) { + LOG_KNOWHERE_ERROR_ << "DeserializeFromFile failed: " << e.what(); + return Status::faiss_inner_error; + } + } + + static std::unique_ptr + StaticCreateConfig() { + return std::make_unique(); + } + + std::unique_ptr + CreateConfig() const override { + return StaticCreateConfig(); + } + + int64_t + Dim() const override { + return index_ ? 
index_->d : 0; + } + + int64_t + Size() const override { + if (!index_) { + return 0; + } + faiss::cppcontrib::knowhere::CountSizeIOWriter writer; + if constexpr (std::is_same_v) { + faiss::write_index(index_.get(), &writer); + } else { + faiss::write_index_binary(index_.get(), &writer); + } + return static_cast(writer.total_size); + } + + int64_t + Count() const override { + return index_ ? index_->ntotal : 0; + } + + std::string + Type() const override { + return knowhere::IndexEnum::INDEX_FAISS; + } + + protected: + std::unique_ptr> index_; + std::shared_ptr search_pool_; + bool is_cosine_{false}; +}; + +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS, FaissIndexNode, fp32, knowhere::feature::MMAP | knowhere::feature::FLOAT32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS, FaissIndexNode, bin1, knowhere::feature::MMAP | knowhere::feature::BINARY); + +} // namespace knowhere diff --git a/src/index/faiss/faiss_config.h b/src/index/faiss/faiss_config.h new file mode 100644 index 000000000..11a356518 --- /dev/null +++ b/src/index/faiss/faiss_config.h @@ -0,0 +1,57 @@ +// Copyright (C) 2019-2026 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this +// file except in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under +// the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +// ANY KIND, either express or implied. See the License for the specific language +// governing permissions and limitations under the License. + +#pragma once + +#include "knowhere/config.h" + +namespace knowhere { + +class FaissConfig : public BaseConfig { + public: + // Required. faiss DSL understood by faiss::index_factory (fp32) or + // faiss::index_binary_factory (bin1). Examples: "Flat", "IVF1024,PQ16x8", + // "HNSW32,Flat", "BIVF256,Hamming". 
+ CFG_STRING faiss_index_name; + + // Captured subset of the incoming JSON: only keys that this config's __DICT__ + // does NOT declare (i.e. not owned by Knowhere's native config layer). Those are + // the keys the vanilla faiss adapter forwards to faiss::ParameterSpace + // (build) and per-family SearchParametersXxx (search). Declared keys (k, + // metric_type, trace_id, faiss_index_name, ...) are consumed by Config::Load + // into typed fields and therefore filtered out of raw_params at capture time. + Json raw_params; + + KNOWHERE_DECLARE_CONFIG(FaissConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(faiss_index_name) + .description("faiss factory string, e.g. \"IVF1024,PQ16x8\"") + .allow_empty_without_default() + .for_train() + .for_deserialize() + .for_deserialize_from_file(); + } + + void + CaptureRawJson(const Json& json) override { + raw_params = Json::object(); + for (auto it = json.begin(); it != json.end(); ++it) { + // Skip any key already declared as a typed field on BaseConfig or + // FaissConfig — those are Knowhere's own and will be consumed by + // Config::Load. Everything else is a faiss-bound knob we forward. + if (__DICT__.count(it.key()) == 0) { + raw_params[it.key()] = it.value(); + } + } + } +}; + +} // namespace knowhere diff --git a/src/index/faiss/faiss_dispatch.cc b/src/index/faiss/faiss_dispatch.cc new file mode 100644 index 000000000..5f1f5218e --- /dev/null +++ b/src/index/faiss/faiss_dispatch.cc @@ -0,0 +1,171 @@ +// Copyright (C) 2019-2026 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this +// file except in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under +// the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +// ANY KIND, either express or implied. 
See the License for the specific language +// governing permissions and limitations under the License. + +#include "index/faiss/faiss_dispatch.h" + +#include +#include +#include +#include +#include +#include + +namespace knowhere::faiss_vanilla { + +namespace { + +// Coerce a json value into a double for faiss consumption. Accepts: +// - numbers: e.g. 16, 16.0 -> 16.0 +// - booleans: true / false -> 1.0 / 0.0 +// - stringified numbers: "16" -> 16.0 +// - stringified booleans: "true" -> 1.0 +// Rejects arrays, objects, null, and unparseable strings. Matches the spirit of +// knowhere::Config::FormatAndCheck's string-to-typed coercion for declared fields, +// so forwarded keys behave consistently with native Knowhere keys. +Status +coerce_to_double(const Json& v, const std::string& key, double* out, std::string* err_msg) { + if (v.is_number()) { + *out = v.get(); + return Status::success; + } + if (v.is_boolean()) { + *out = v.get() ? 1.0 : 0.0; + return Status::success; + } + if (v.is_string()) { + const std::string s = v.get(); + if (s == "true") { + *out = 1.0; + return Status::success; + } + if (s == "false") { + *out = 0.0; + return Status::success; + } + try { + size_t pos = 0; + double parsed = std::stod(s, &pos); + if (pos == s.size()) { + *out = parsed; + return Status::success; + } + } catch (const std::invalid_argument&) { + } catch (const std::out_of_range&) { + } + } + if (err_msg) { + *err_msg = "faiss vanilla: param '" + key + "' expects a number or boolean; got " + v.dump(); + } + return Status::invalid_args; +} + +// Apply every key in raw_params to the faiss index. raw_params has already been +// filtered by FaissConfig::CaptureRawJson to exclude keys owned by Knowhere's own +// config layer (fields declared via KNOWHERE_CONFIG_DECLARE_FIELD). We pre-validate +// the remaining keys against the faiss-owned whitelist (supported_build_param_names +// + "quantizer_*" prefix handling) before calling ParameterSpace. 
A key that fails +// the whitelist (typo, non-faiss param) is rejected with a clear error; a key that +// passes the whitelist but is incompatible with the concrete index type (e.g. +// nprobe on an HNSW) is still caught by ParameterSpace's exception and surfaced +// as invalid_args. +template +Status +apply_impl(IndexT* index, const Json& raw_params, std::string* err_msg) { + ::faiss::ParameterSpace ps; + for (auto it = raw_params.begin(); it != raw_params.end(); ++it) { + const std::string& key = it.key(); + if (!::faiss::cppcontrib::knowhere::is_supported_build_param(key)) { + if (err_msg) { + *err_msg = "faiss vanilla: build param '" + key + "' is not recognized"; + } + return Status::invalid_args; + } + double val = 0.0; + auto cst = coerce_to_double(it.value(), key, &val, err_msg); + if (cst != Status::success) { + return cst; + } + try { + ps.set_index_parameter(index, key, val); + } catch (const ::faiss::FaissException& e) { + if (err_msg) { + *err_msg = std::string("faiss rejected param '") + key + "': " + e.what(); + } + return Status::invalid_args; + } + } + return Status::success; +} + +// Shared logic for search-param builders. `index` can be faiss::Index* or IndexBinary*. +// raw_params has already been filtered by FaissConfig::CaptureRawJson to contain only +// keys NOT declared by Knowhere's typed config. Uses the faiss-owned whitelist +// (supported_search_params) to validate remaining keys, and delegates both the +// SearchParameters-family selection and the per-name field set to the upstream +// helper. Knowhere layer only adds: (1) sel attach, (2) JSON->double conversion, +// (3) clear error wording. 
+template +Status +build_search_params_impl(const IndexT* index, const Json& raw_params, ::faiss::IDSelector* sel, + std::unique_ptr<::faiss::SearchParameters>* out, std::string* err_msg) { + auto params = ::faiss::cppcontrib::knowhere::make_search_params(index); + params->sel = sel; + + const auto supported = ::faiss::cppcontrib::knowhere::supported_search_params(index); + for (auto it = raw_params.begin(); it != raw_params.end(); ++it) { + const std::string& key = it.key(); + if (!supported.count(key)) { + if (err_msg) { + *err_msg = "faiss vanilla: search param '" + key + "' not supported for this index family"; + } + return Status::invalid_args; + } + double val = 0.0; + auto cst = coerce_to_double(it.value(), key, &val, err_msg); + if (cst != Status::success) { + return cst; + } + // Whitelist already guarantees try_set_search_param returns true; treat a + // false here as an invariant breach rather than user error. + (void)::faiss::cppcontrib::knowhere::try_set_search_param(params.get(), key, val); + } + *out = std::move(params); + return Status::success; +} + +} // namespace + +Status +apply_build_params(::faiss::Index* index, const Json& raw_params, std::string* err_msg) { + return apply_impl(index, raw_params, err_msg); +} + +Status +apply_build_params(::faiss::IndexBinary* index, const Json& raw_params, std::string* err_msg) { + return apply_impl(index, raw_params, err_msg); +} + +Status +build_search_params(const ::faiss::Index* index, const Json& raw_params, ::faiss::IDSelector* sel, + std::unique_ptr<::faiss::SearchParameters>* out, std::string* err_msg) { + return build_search_params_impl(index, raw_params, sel, out, err_msg); +} + +Status +build_search_params(const ::faiss::IndexBinary* index, const Json& raw_params, ::faiss::IDSelector* sel, + std::unique_ptr<::faiss::SearchParameters>* out, std::string* err_msg) { + // IndexBinaryIVF requires SearchParametersIVF; binary side also does not honor + // IDSelector, so attaching sel here is typically a 
no-op at search time. + return build_search_params_impl(index, raw_params, sel, out, err_msg); +} + +} // namespace knowhere::faiss_vanilla diff --git a/src/index/faiss/faiss_dispatch.h b/src/index/faiss/faiss_dispatch.h new file mode 100644 index 000000000..856922c94 --- /dev/null +++ b/src/index/faiss/faiss_dispatch.h @@ -0,0 +1,50 @@ +// Copyright (C) 2019-2026 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this +// file except in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under +// the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +// ANY KIND, either express or implied. See the License for the specific language +// governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "knowhere/config.h" + +namespace faiss { +struct Index; +struct IndexBinary; +struct IDSelector; +struct SearchParameters; +} // namespace faiss + +namespace knowhere::faiss_vanilla { + +// Forwards keys from raw_params to faiss::ParameterSpace::set_index_parameter +// on the given index. Converts faiss exceptions into Status::invalid_args with the +// faiss message in *err_msg. +Status +apply_build_params(::faiss::Index* index, const Json& raw_params, std::string* err_msg); + +Status +apply_build_params(::faiss::IndexBinary* index, const Json& raw_params, std::string* err_msg); + +// Build a per-request SearchParameters* appropriate for the concrete faiss index +// family. The family dispatch itself lives in faiss::cppcontrib::knowhere (upstream- +// bound helper); this wrapper adds: (1) sel assignment, (2) framework-key filtering, +// (3) JSON value extraction + unknown-key error surfacing. 
+Status +build_search_params(const ::faiss::Index* index, const Json& raw_params, ::faiss::IDSelector* sel, + std::unique_ptr<::faiss::SearchParameters>* out, std::string* err_msg); + +Status +build_search_params(const ::faiss::IndexBinary* index, const Json& raw_params, ::faiss::IDSelector* sel, + std::unique_ptr<::faiss::SearchParameters>* out, std::string* err_msg); + +} // namespace knowhere::faiss_vanilla diff --git a/src/index/index.cc b/src/index/index.cc index b2af8cfb2..336581c9c 100644 --- a/src/index/index.cc +++ b/src/index/index.cc @@ -33,6 +33,7 @@ LoadConfig(BaseConfig* cfg, const Json& json, knowhere::PARAM_TYPE param_type, c auto res = Config::FormatAndCheck(*cfg, json_, msg); LOG_KNOWHERE_DEBUG_ << method << " config dump: " << json_.dump(); RETURN_IF_ERROR(res); + cfg->CaptureRawJson(json_); return Config::Load(*cfg, json_, param_type, msg); } diff --git a/tests/ut/test_faiss_vanilla.cc b/tests/ut/test_faiss_vanilla.cc new file mode 100644 index 000000000..36cfb8540 --- /dev/null +++ b/tests/ut/test_faiss_vanilla.cc @@ -0,0 +1,448 @@ +// Copyright (C) 2019-2026 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this +// file except in compliance with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 + +#include +#include + +#include "catch2/catch_test_macros.hpp" +#include "index/faiss/faiss_config.h" +#include "knowhere/bitsetview.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/config.h" +#include "knowhere/dataset.h" +#include "knowhere/index/index_factory.h" + +namespace { +knowhere::DataSetPtr +gen_fp32(size_t nb, size_t dim, int64_t seed = 42) { + auto* xb = new float[nb * dim]; + std::mt19937 gen(seed); + std::normal_distribution dist(0.0f, 1.0f); + for (size_t i = 0; i < nb * dim; ++i) xb[i] = dist(gen); + auto ds = knowhere::GenDataSet(nb, dim, xb); + ds->SetIsOwner(true); + return ds; +} + +knowhere::DataSetPtr +gen_bin(size_t nb, size_t dim_bits, uint64_t seed = 42) { + const size_t bytes = (dim_bits + 7) / 8; + auto* xb = new uint8_t[nb * bytes]; + std::mt19937_64 rng(seed); + for (size_t i = 0; i < nb * bytes; ++i) xb[i] = static_cast(rng()); + auto ds = knowhere::GenDataSet(nb, dim_bits, xb); + ds->SetIsOwner(true); + return ds; +} +} // namespace + +TEST_CASE("FaissConfig parses faiss_index_name and captures raw JSON", "[faiss_vanilla]") { + knowhere::FaissConfig cfg; + knowhere::Json j = + knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF256,Flat","nprobe":16,"efSearch":32})"); + std::string msg; + + // Replicate LoadConfig's internal sequence using only public-header entry points. 
+ knowhere::Json j_(j); + REQUIRE(knowhere::Config::FormatAndCheck(cfg, j_, &msg) == knowhere::Status::success); + cfg.CaptureRawJson(j_); + REQUIRE(knowhere::Config::Load(cfg, j_, knowhere::TRAIN, &msg) == knowhere::Status::success); + + REQUIRE(cfg.faiss_index_name.value() == "IVF256,Flat"); + REQUIRE(cfg.raw_params.contains("nprobe")); + REQUIRE(cfg.raw_params["nprobe"] == 16); + REQUIRE(cfg.raw_params.contains("efSearch")); +} + +TEST_CASE("IndexFactory creates FAISS index for fp32", "[faiss_vanilla]") { + auto version = knowhere::Version::GetCurrentVersion(); + auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS, + version.VersionNumber()); + REQUIRE(idx.has_value()); + REQUIRE(idx.value().Type() == knowhere::IndexEnum::INDEX_FAISS); +} + +TEST_CASE("FAISS Train+Add Flat smoke", "[faiss_vanilla]") { + const size_t nb = 1000, dim = 16; + auto version = knowhere::Version::GetCurrentVersion(); + auto idx = knowhere::IndexFactory::Instance() + .Create(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber()) + .value(); + knowhere::Json j = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","dim":16})"); + auto ds = gen_fp32(nb, dim); + REQUIRE(idx.Build(ds, j) == knowhere::Status::success); + REQUIRE(idx.Count() == static_cast(nb)); + REQUIRE(idx.Dim() == static_cast(dim)); +} + +TEST_CASE("FAISS Train forwards parameters via ParameterSpace", "[faiss_vanilla]") { + const size_t nb = 2000, dim = 32; + auto version = knowhere::Version::GetCurrentVersion(); + auto idx = knowhere::IndexFactory::Instance() + .Create(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber()) + .value(); + // nprobe is a search-time knob on IVF but ParameterSpace will accept it at build + // time by setting the field directly. This verifies the forwarding plumbing. 
+    knowhere::Json j =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF64,Flat","dim":32,"nprobe":8})");
+    auto ds = gen_fp32(nb, dim);
+    REQUIRE(idx.Build(ds, j) == knowhere::Status::success);
+}
+
+TEST_CASE("FAISS Search on Flat returns exact KNN", "[faiss_vanilla]") {
+    const size_t nb = 500, dim = 8, nq = 3, k = 5;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json build = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","dim":8})");
+    auto base = gen_fp32(nb, dim);
+    REQUIRE(idx.Build(base, build) == knowhere::Status::success);
+
+    knowhere::Json search = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","k":5})");
+    auto queries = gen_fp32(nq, dim, /*seed=*/7);
+    auto res = idx.Search(queries, search, nullptr);
+    REQUIRE(res.has_value());
+    REQUIRE(res.value()->GetRows() == static_cast<int64_t>(nq));
+    const auto* ids = res.value()->GetIds();
+    for (size_t q = 0; q < nq; ++q) {
+        for (size_t j = 0; j < k; ++j) {
+            REQUIRE(ids[q * k + j] >= 0);
+            REQUIRE(ids[q * k + j] < static_cast<int64_t>(nb));
+        }
+    }
+}
+
+TEST_CASE("FAISS Search accepts nprobe on IVF via SearchParametersIVF", "[faiss_vanilla]") {
+    const size_t nb = 2000, dim = 16, nq = 4;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json build = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF64,Flat","dim":16})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), build) == knowhere::Status::success);
+
+    knowhere::Json search =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF64,Flat","k":10,"nprobe":8})");
+    auto res = idx.Search(gen_fp32(nq, dim, 99), search, nullptr);
+    REQUIRE(res.has_value());
+}
+
+TEST_CASE("FAISS Search honors BitsetView filter", "[faiss_vanilla]") {
+    const size_t nb = 200, dim = 8, nq = 1, k = 10;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json j = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","dim":8,"k":10})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), j) == knowhere::Status::success);
+
+    // Filter out ids [0, 50) — set those bits to 1 (filtered)
+    std::vector<uint8_t> bits((nb + 7) / 8, 0);
+    for (size_t i = 0; i < 50; ++i) bits[i / 8] |= (1 << (i % 8));
+    knowhere::BitsetView bitset(bits.data(), nb);
+
+    auto res = idx.Search(gen_fp32(nq, dim, 3), j, bitset);
+    REQUIRE(res.has_value());
+    const auto* ids = res.value()->GetIds();
+    for (size_t i = 0; i < nq * k; ++i) {
+        REQUIRE(ids[i] >= 50);  // any id < 50 would mean filtering is broken
+    }
+}
+
+TEST_CASE("FAISS RangeSearch supported on Flat", "[faiss_vanilla]") {
+    const size_t nb = 100, dim = 8;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json build = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","dim":8})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), build) == knowhere::Status::success);
+
+    knowhere::Json search =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","radius":100.0,"range_filter":0.0})");
+    auto query = gen_fp32(1, dim, 55);
+    auto res = idx.RangeSearch(query, search, nullptr);
+    REQUIRE(res.has_value());
+}
+
+TEST_CASE("FAISS HasRawData/GetVectorByIds unsupported by vanilla adapter", "[faiss_vanilla]") {
+    auto version = knowhere::Version::GetCurrentVersion();
+
+    auto check_unsupported = [&](const std::string& factory_str, size_t nb, size_t dim) {
+        auto idx = knowhere::IndexFactory::Instance()
+                       .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                       .value();
+        knowhere::Json j = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":")" + factory_str +
+                                                 R"(","dim":)" + std::to_string(dim) + "}");
+        REQUIRE(idx.Build(gen_fp32(nb, dim), j) == knowhere::Status::success);
+
+        REQUIRE(idx.HasRawData("L2") == false);
+
+        int64_t query_id = 5;
+        auto ids_ds = knowhere::GenIdsDataSet(1, &query_id);
+        auto r = idx.GetVectorByIds(ids_ds);
+        REQUIRE_FALSE(r.has_value());
+        REQUIRE(r.error() == knowhere::Status::not_implemented);
+    };
+
+    check_unsupported("Flat", 64, 8);
+    check_unsupported("IVF64,Flat", 256, 8);
+}
+
+TEST_CASE("FAISS Serialize/Deserialize roundtrip", "[faiss_vanilla]") {
+    const size_t nb = 200, dim = 8;
+    auto version = knowhere::Version::GetCurrentVersion();
+
+    auto idx1 = knowhere::IndexFactory::Instance()
+                    .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                    .value();
+    knowhere::Json j = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","dim":8,"k":3})");
+    REQUIRE(idx1.Build(gen_fp32(nb, dim), j) == knowhere::Status::success);
+
+    knowhere::BinarySet bs;
+    REQUIRE(idx1.Serialize(bs) == knowhere::Status::success);
+
+    auto idx2 = knowhere::IndexFactory::Instance()
+                    .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                    .value();
+    REQUIRE(idx2.Deserialize(bs, j) == knowhere::Status::success);
+    REQUIRE(idx2.Count() == static_cast<int64_t>(nb));
+    REQUIRE(idx2.Dim() == static_cast<int64_t>(dim));
+
+    // Both indexes must produce identical KNN for the same query.
+    auto q = gen_fp32(1, dim, 777);
+    auto r1 = idx1.Search(q, j, nullptr).value();
+    auto r2 = idx2.Search(q, j, nullptr).value();
+    for (int64_t i = 0; i < 3; ++i) {
+        REQUIRE(r1->GetIds()[i] == r2->GetIds()[i]);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Task 10: Binary path end-to-end test
+// ---------------------------------------------------------------------------
+
+TEST_CASE("FAISS binary: BFlat build + search", "[faiss_vanilla]") {
+    // Use BFlat (brute-force binary) rather than BIVF for the smoke test:
+    //   - Exercises the bin1 IndexNode path end-to-end (index_binary_factory,
+    //     write_index_binary / read_index_binary, binary search with int32
+    //     distance → float projection).
+    //   - Avoids IndexBinaryIVF::train → Clustering::train_encoded →
+    //     IndexLSH::sa_decode, an upstream faiss path where ASAN flags a
+    //     heap-use-after-free under the cross-test malloc reuse pattern of the
+    //     knowhere UT binary. That's an upstream bug unrelated to this adapter.
+    const size_t nb = 1024, dim_bits = 64, nq = 2;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::bin1>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json j = knowhere::Json::parse(R"({"metric_type":"HAMMING","faiss_index_name":"BFlat","dim":64,"k":5})");
+    REQUIRE(idx.Build(gen_bin(nb, dim_bits), j) == knowhere::Status::success);
+    REQUIRE(idx.Count() == static_cast<int64_t>(nb));
+    auto res = idx.Search(gen_bin(nq, dim_bits, 3), j, nullptr);
+    REQUIRE(res.has_value());
+}
+
+// ---------------------------------------------------------------------------
+// Task 11: Error-case tests
+// ---------------------------------------------------------------------------
+
+TEST_CASE("FAISS: invalid faiss_index_name returns invalid_args", "[faiss_vanilla]") {
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json j =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"NotARealFactoryString","dim":8})");
+    auto st = idx.Build(gen_fp32(32, 8), j);
+    REQUIRE(st == knowhere::Status::invalid_args);
+}
+
+TEST_CASE("FAISS: typo key surfaces faiss error at build", "[faiss_vanilla]") {
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    // "n_probe" (with underscore) is wrong; the real key is "nprobe".
+    // faiss::ParameterSpace::set_index_parameter throws on unknown knobs.
+    // The adapter translates that to invalid_args.
+    knowhere::Json j =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF32,Flat","dim":8,"n_probe":4})");
+    auto st = idx.Build(gen_fp32(64, 8), j);
+    REQUIRE(st == knowhere::Status::invalid_args);
+}
+
+TEST_CASE("FAISS: search key unknown to family returns invalid_args", "[faiss_vanilla]") {
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json jb = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","dim":8})");
+    REQUIRE(idx.Build(gen_fp32(64, 8), jb) == knowhere::Status::success);
+    // efSearch is an HNSW knob; Flat uses base SearchParameters and does not accept it.
+    knowhere::Json jq = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","k":3,"efSearch":32})");
+    auto res = idx.Search(gen_fp32(1, 8), jq, nullptr);
+    REQUIRE_FALSE(res.has_value());
+}
+
+// ---------------------------------------------------------------------------
+// Task 12: Size() memory estimate
+// ---------------------------------------------------------------------------
+
+TEST_CASE("FAISS Size() gives a non-zero estimate after Build", "[faiss_vanilla]") {
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json j = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"Flat","dim":8})");
+    REQUIRE(idx.Build(gen_fp32(100, 8), j) == knowhere::Status::success);
+    REQUIRE(idx.Size() > 0);
+}
+
+// ---------------------------------------------------------------------------
+// Task 13: Concurrent search isolation
+// ---------------------------------------------------------------------------
+
+TEST_CASE("FAISS: concurrent searches with varying nprobe are isolated", "[faiss_vanilla]") {
+    const size_t nb = 2000, dim = 16;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json jb = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF64,Flat","dim":16})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), jb) == knowhere::Status::success);
+
+    // Catch2 assertion macros must only be used from the test's own thread, and a
+    // failing REQUIRE throws, which would terminate the process if it escaped a
+    // std::thread. Workers therefore only count successful searches into a slot
+    // they own exclusively; the main thread asserts once both have been joined.
+    constexpr int kIters = 20;
+    auto worker = [&](int nprobe, int* ok_count) {
+        for (int i = 0; i < kIters; ++i) {
+            knowhere::Json jq = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF64,Flat","k":5})");
+            jq["nprobe"] = nprobe;
+            auto res = idx.Search(gen_fp32(1, dim, nprobe * 100 + i), jq, nullptr);
+            if (res.has_value()) {
+                ++*ok_count;
+            }
+        }
+    };
+    int ok_low = 0;
+    int ok_high = 0;
+    std::thread t1(worker, 4, &ok_low);
+    std::thread t2(worker, 32, &ok_high);
+    t1.join();
+    t2.join();
+    REQUIRE(ok_low == kIters);
+    REQUIRE(ok_high == kIters);
+}
+
+// PreTransform wrapper: OPQ16,IVF64,PQ16x4 — outer is IndexPreTransform, inner IVFPQ.
+// Verifies build_search_params recurses through PreTransform and forwards nprobe to the
+// inner IVF SearchParameters.
+TEST_CASE("FAISS PreTransform: nprobe propagates through OPQ to IVFPQ", "[faiss_vanilla]") {
+    const size_t nb = 4096, dim = 16, nq = 4;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json build =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"OPQ16,IVF64,PQ16x4","dim":16})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), build) == knowhere::Status::success);
+
+    knowhere::Json search =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"OPQ16,IVF64,PQ16x4","k":5,"nprobe":8})");
+    auto res = idx.Search(gen_fp32(nq, dim, 11), search, nullptr);
+    REQUIRE(res.has_value());
+}
+
+// Refine wrapper: IVF64,PQ8x4,RFlat. Verify k_factor is consumed at the wrapper layer
+// and nprobe is forwarded to the base IVF.
+TEST_CASE("FAISS Refine: k_factor + base nprobe both honored", "[faiss_vanilla]") {
+    const size_t nb = 4096, dim = 16, nq = 4;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json build =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF64,PQ8x4,RFlat","dim":16})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), build) == knowhere::Status::success);
+
+    knowhere::Json search = knowhere::Json::parse(
+        R"({"metric_type":"L2","faiss_index_name":"IVF64,PQ8x4,RFlat","k":5,"nprobe":8,"k_factor":2.0})");
+    auto res = idx.Search(gen_fp32(nq, dim, 13), search, nullptr);
+    REQUIRE(res.has_value());
+}
+
+#ifdef FAISS_ENABLE_SVS
+// SVS Vamana — verify search_window_size is recognized at the SVS leaf branch.
+// Compiled only in SVS-enabled builds (e.g. production X86 image).
+TEST_CASE("FAISS SVS Vamana: search_window_size passed through", "[faiss_vanilla]") {
+    const size_t nb = 4096, dim = 16, nq = 4;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json build = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"SVSVamana64","dim":16})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), build) == knowhere::Status::success);
+
+    knowhere::Json search =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"SVSVamana64","k":5,"search_window_size":32})");
+    auto res = idx.Search(gen_fp32(nq, dim, 19), search, nullptr);
+    REQUIRE(res.has_value());
+}
+#endif
+
+// Stringified numeric/boolean values should be accepted (matches Knowhere's
+// native Config::FormatAndCheck convention for declared fields).
+TEST_CASE("FAISS: stringified nprobe is coerced to number", "[faiss_vanilla]") {
+    const size_t nb = 2000, dim = 16, nq = 4;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json jb = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF64,Flat","dim":16})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), jb) == knowhere::Status::success);
+
+    knowhere::Json jq =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"IVF64,Flat","k":5,"nprobe":"16"})");
+    auto res = idx.Search(gen_fp32(nq, dim, 3), jq, nullptr);
+    REQUIRE(res.has_value());
+}
+
+TEST_CASE("FAISS: stringified bool is coerced", "[faiss_vanilla]") {
+    const size_t nb = 1000, dim = 16, nq = 1;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json jb = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"HNSW16,Flat","dim":16})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), jb) == knowhere::Status::success);
+
+    knowhere::Json jq = knowhere::Json::parse(
+        R"({"metric_type":"L2","faiss_index_name":"HNSW16,Flat","k":5,"check_relative_distance":"false"})");
+    auto res = idx.Search(gen_fp32(nq, dim, 3), jq, nullptr);
+    REQUIRE(res.has_value());
+}
+
+TEST_CASE("FAISS: unparseable string param is rejected with clear error", "[faiss_vanilla]") {
+    const size_t nb = 500, dim = 16;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json jb = knowhere::Json::parse(
+        R"({"metric_type":"L2","faiss_index_name":"IVF64,Flat","dim":16,"nprobe":"not_a_number"})");
+    auto st = idx.Build(gen_fp32(nb, dim), jb);
+    REQUIRE(st == knowhere::Status::invalid_args);
+}
+
+// Standalone IndexPQ — verify polysemous_ht is recognized at the PQ leaf branch.
+TEST_CASE("FAISS standalone PQ: polysemous_ht passed through", "[faiss_vanilla]") {
+    const size_t nb = 4096, dim = 16, nq = 4;
+    auto version = knowhere::Version::GetCurrentVersion();
+    auto idx = knowhere::IndexFactory::Instance()
+                   .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS, version.VersionNumber())
+                   .value();
+    knowhere::Json build = knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"PQ8x4","dim":16})");
+    REQUIRE(idx.Build(gen_fp32(nb, dim), build) == knowhere::Status::success);
+
+    knowhere::Json search =
+        knowhere::Json::parse(R"({"metric_type":"L2","faiss_index_name":"PQ8x4","k":5,"polysemous_ht":24})");
+    auto res = idx.Search(gen_fp32(nq, dim, 17), search, nullptr);
+    REQUIRE(res.has_value());
+}
diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/SearchParamsDispatch.cpp b/thirdparty/faiss/faiss/cppcontrib/knowhere/SearchParamsDispatch.cpp
new file mode 100644
index 000000000..21726bdf4
--- /dev/null
+++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/SearchParamsDispatch.cpp
@@ -0,0 +1,299 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <faiss/cppcontrib/knowhere/SearchParamsDispatch.h>
+
+#include <faiss/Index.h>
+#include <faiss/IndexBinary.h>
+#include <faiss/IndexBinaryIVF.h>
+#include <faiss/IndexHNSW.h>
+#include <faiss/IndexIVF.h>
+#include <faiss/IndexPQ.h>
+#include <faiss/IndexPreTransform.h>
+#include <faiss/IndexRefine.h>
+
+#ifdef FAISS_ENABLE_SVS
+#include <faiss/svs/IndexSVSVamana.h>
+#endif
+
+namespace faiss::cppcontrib::knowhere {
+
+namespace {
+
+// The wrapper subclasses SearchParametersPreTransform and
+// IndexRefineSearchParameters hold raw (non-owning) pointers to the nested
+// sub-params (per faiss header comments: "non owning"). When we build them via
+// make_search_params we want the caller to own the whole tree with a single
+// unique_ptr, so we bundle the inner unique_ptr into these Owning* subclasses.
+// The Owning flavors decay to the base faiss types for all consumers (faiss's
+// search() implementations see them as the base class).
+struct OwningSearchParametersPreTransform
+        : ::faiss::SearchParametersPreTransform {
+    std::unique_ptr<::faiss::SearchParameters> inner_owned;
+};
+struct OwningIndexRefineSearchParameters
+        : ::faiss::IndexRefineSearchParameters {
+    std::unique_ptr<::faiss::SearchParameters> inner_owned;
+};
+
+bool try_set_ivf(
+        ::faiss::SearchParametersIVF* p,
+        const std::string& name,
+        double val) {
+    if (name == "nprobe") {
+        p->nprobe = static_cast<size_t>(val);
+        return true;
+    }
+    if (name == "max_codes") {
+        p->max_codes = static_cast<size_t>(val);
+        return true;
+    }
+    return false;
+}
+
+bool try_set_hnsw(
+        ::faiss::SearchParametersHNSW* p,
+        const std::string& name,
+        double val) {
+    if (name == "efSearch") {
+        p->efSearch = static_cast<int>(val);
+        return true;
+    }
+    if (name == "check_relative_distance") {
+        p->check_relative_distance = val != 0.0;
+        return true;
+    }
+    if (name == "bounded_queue") {
+        p->bounded_queue = val != 0.0;
+        return true;
+    }
+    return false;
+}
+
+bool try_set_pq(
+        ::faiss::SearchParametersPQ* p,
+        const std::string& name,
+        double val) {
+    if (name == "polysemous_ht") {
+        p->polysemous_ht = static_cast<int>(val);
+        return true;
+    }
+    if (name == "search_type") {
+        p->search_type = static_cast<::faiss::IndexPQ::Search_type_t>(
+                static_cast<int>(val));
+        return true;
+    }
+    return false;
+}
+
+#ifdef FAISS_ENABLE_SVS
+bool try_set_svs_vamana(
+        ::faiss::SearchParametersSVSVamana* p,
+        const std::string& name,
+        double val) {
+    if (name == "search_window_size") {
+        p->search_window_size = static_cast<size_t>(val);
+        return true;
+    }
+    if (name == "search_buffer_capacity") {
+        p->search_buffer_capacity = static_cast<size_t>(val);
+        return true;
+    }
+    return false;
+}
+#endif
+
+} // namespace
+
+std::unique_ptr<::faiss::SearchParameters> make_search_params(
+        const ::faiss::Index* index) {
+    // Wrapper: PreTransform (OPQ, PCA, etc.). Recurse to inner index; no own
+    // knobs.
+    if (auto* pt = dynamic_cast<const ::faiss::IndexPreTransform*>(index)) {
+        auto inner = make_search_params(pt->index);
+        auto p = std::make_unique<OwningSearchParametersPreTransform>();
+        p->index_params = inner.get();
+        p->inner_owned = std::move(inner);
+        return p;
+    }
+    // Wrapper: Refine (RFlat, Refine(...)). Has k_factor knob at this layer.
+    if (auto* rfn = dynamic_cast<const ::faiss::IndexRefine*>(index)) {
+        auto inner = make_search_params(rfn->base_index);
+        auto p = std::make_unique<OwningIndexRefineSearchParameters>();
+        p->base_index_params = inner.get();
+        p->inner_owned = std::move(inner);
+        return p;
+    }
+    // Leaf families (order matters only for disjoint casts; these are mutually
+    // exclusive).
+    if (dynamic_cast<const ::faiss::IndexHNSW*>(index)) {
+        return std::make_unique<::faiss::SearchParametersHNSW>();
+    }
+    if (dynamic_cast<const ::faiss::IndexIVF*>(index)) {
+        return std::make_unique<::faiss::SearchParametersIVF>();
+    }
+    if (dynamic_cast<const ::faiss::IndexPQ*>(index)) {
+        return std::make_unique<::faiss::SearchParametersPQ>();
+    }
+#ifdef FAISS_ENABLE_SVS
+    if (dynamic_cast<const ::faiss::IndexSVSVamana*>(index)) {
+        // Catches IndexSVSVamana, IndexSVSVamanaLVQ, IndexSVSVamanaLeanVec.
+        return std::make_unique<::faiss::SearchParametersSVSVamana>();
+    }
+#endif
+    return std::make_unique<::faiss::SearchParameters>();
+}
+
+std::unique_ptr<::faiss::SearchParameters> make_search_params(
+        const ::faiss::IndexBinary* index) {
+    // IndexBinaryIVF::search uses dynamic_cast to SearchParametersIVF; giving
+    // it a plain SearchParameters would fail the check. Binary IVF also does
+    // not honor IDSelector, so callers should not set `sel` on the returned
+    // object for BIVF.
+    if (dynamic_cast<const ::faiss::IndexBinaryIVF*>(index)) {
+        return std::make_unique<::faiss::SearchParametersIVF>();
+    }
+    return std::make_unique<::faiss::SearchParameters>();
+}
+
+// ---------- supported-name whitelists (query-only, no mutation) ----------
+
+namespace {
+
+// Names accepted by try_set_ivf above. Keep in sync.
+const std::set<std::string>& ivf_names() {
+    static const std::set<std::string> kNames = {"nprobe", "max_codes"};
+    return kNames;
+}
+
+const std::set<std::string>& hnsw_names() {
+    static const std::set<std::string> kNames = {
+            "efSearch", "check_relative_distance", "bounded_queue"};
+    return kNames;
+}
+
+const std::set<std::string>& pq_names() {
+    static const std::set<std::string> kNames = {
+            "polysemous_ht", "search_type"};
+    return kNames;
+}
+
+#ifdef FAISS_ENABLE_SVS
+const std::set<std::string>& svs_vamana_names() {
+    static const std::set<std::string> kNames = {
+            "search_window_size", "search_buffer_capacity"};
+    return kNames;
+}
+#endif
+
+} // namespace
+
+std::set<std::string> supported_search_params(const ::faiss::Index* index) {
+    // Wrappers: union of own knobs and inner supported set.
+    if (auto* pt = dynamic_cast<const ::faiss::IndexPreTransform*>(index)) {
+        return supported_search_params(pt->index); // no own knobs
+    }
+    if (auto* rfn = dynamic_cast<const ::faiss::IndexRefine*>(index)) {
+        auto out = supported_search_params(rfn->base_index);
+        out.insert("k_factor");
+        return out;
+    }
+    if (dynamic_cast<const ::faiss::IndexHNSW*>(index)) {
+        return hnsw_names();
+    }
+    if (dynamic_cast<const ::faiss::IndexIVF*>(index)) {
+        return ivf_names();
+    }
+    if (dynamic_cast<const ::faiss::IndexPQ*>(index)) {
+        return pq_names();
+    }
+#ifdef FAISS_ENABLE_SVS
+    if (dynamic_cast<const ::faiss::IndexSVSVamana*>(index)) {
+        return svs_vamana_names();
+    }
+#endif
+    return {}; // plain Index: only sel, no named knobs
+}
+
+std::set<std::string> supported_search_params(
+        const ::faiss::IndexBinary* index) {
+    if (dynamic_cast<const ::faiss::IndexBinaryIVF*>(index)) {
+        return ivf_names();
+    }
+    return {};
+}
+
+std::set<std::string> supported_build_param_names() {
+    // Mirror the hardcoded if-chain in
+    // faiss::ParameterSpace::set_index_parameter (AutoTune.cpp).
+    // "quantizer_" is handled via prefix in is_supported_build_param.
+    static const std::set<std::string> kNames = {
+            "nprobe",
+            "ht",
+            "k_factor",
+            "max_codes",
+            "prune_headroom",
+            "efConstruction",
+            "efSearch",
+    };
+    return kNames;
+}
+
+bool is_supported_build_param(const std::string& name) {
+    if (supported_build_param_names().count(name)) {
+        return true;
+    }
+    // ParameterSpace recursively forwards keys starting with "quantizer_" into
+    // the coarse quantizer of an IVF index. Validate the suffix against the
+    // same list.
+    constexpr const char kQuantizerPrefix[] = "quantizer_";
+    constexpr size_t kPrefixLen = sizeof(kQuantizerPrefix) - 1;
+    if (name.compare(0, kPrefixLen, kQuantizerPrefix) == 0) {
+        return is_supported_build_param(name.substr(kPrefixLen));
+    }
+    return false;
+}
+
+// ---------- runtime setter (walks into wrappers) ----------
+
+bool try_set_search_param(
+        ::faiss::SearchParameters* params,
+        const std::string& name,
+        double val) {
+    // Wrappers first: try this layer's own knobs, then recurse to inner params.
+    if (auto* pt =
+                dynamic_cast<::faiss::SearchParametersPreTransform*>(params)) {
+        // PreTransform has no own knobs; forward to inner.
+        return pt->index_params &&
+                try_set_search_param(pt->index_params, name, val);
+    }
+    if (auto* rfn =
+                dynamic_cast<::faiss::IndexRefineSearchParameters*>(params)) {
+        if (name == "k_factor") {
+            rfn->k_factor = static_cast<float>(val);
+            return true;
+        }
+        return rfn->base_index_params &&
+                try_set_search_param(rfn->base_index_params, name, val);
+    }
+    // Leaves.
+    if (auto* ivf = dynamic_cast<::faiss::SearchParametersIVF*>(params)) {
+        return try_set_ivf(ivf, name, val);
+    }
+    if (auto* hnsw = dynamic_cast<::faiss::SearchParametersHNSW*>(params)) {
+        return try_set_hnsw(hnsw, name, val);
+    }
+    if (auto* pq = dynamic_cast<::faiss::SearchParametersPQ*>(params)) {
+        return try_set_pq(pq, name, val);
+    }
+#ifdef FAISS_ENABLE_SVS
+    if (auto* svs = dynamic_cast<::faiss::SearchParametersSVSVamana*>(params)) {
+        return try_set_svs_vamana(svs, name, val);
+    }
+#endif
+    return false;
+}
+
+} // namespace faiss::cppcontrib::knowhere
diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/SearchParamsDispatch.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/SearchParamsDispatch.h
new file mode 100644
index 000000000..d90d228e3
--- /dev/null
+++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/SearchParamsDispatch.h
@@ -0,0 +1,87 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Generic per-family SearchParameters helpers. Intended to be upstreamable:
+// the knowhere project uses it today, but the design is index-family aware
+// yet metric/config agnostic — it could live in main faiss if accepted.
+//
+// Motivation: a C++ caller wanting a per-request SearchParameters object for
+// an arbitrary faiss::Index currently has to hand-roll a family dispatch
+// (IVF -> SearchParametersIVF, HNSW -> SearchParametersHNSW, etc.) and
+// recurse through wrapper indexes (PreTransform, Refine). This header
+// centralizes that dispatch and exposes two primitives:
+//
+//   1. make_search_params(index)
+//      Returns a unique_ptr<SearchParameters> of the correct concrete type
+//      for the given index, including recursive inner params for wrapper
+//      indexes. Ownership of nested params is held inside the returned
+//      object so the caller can treat it as a single unique_ptr.
+//
+//   2.
try_set_search_param(params, name, value)
+//      Sets a named runtime knob on the given SearchParameters object,
+//      walking into nested sub-params for wrapper classes. Returns whether
+//      the name was recognized. Intended for loops over user-supplied
+//      key/value config — caller handles "unknown key -> error".
+
+#pragma once
+
+#include <memory>
+#include <set>
+#include <string>
+
+namespace faiss {
+struct Index;
+struct IndexBinary;
+struct SearchParameters;
+} // namespace faiss
+
+namespace faiss::cppcontrib::knowhere {
+
+// ---------- Search-param dispatch ----------
+
+// Construct the appropriate SearchParameters subclass for the given index
+// family. For wrapper indexes (PreTransform, Refine) this recurses into the
+// inner index. Ownership of any nested SearchParameters is held by the returned
+// unique_ptr.
+std::unique_ptr<::faiss::SearchParameters> make_search_params(
+        const ::faiss::Index* index);
+
+// Binary variant. IndexBinaryIVF needs SearchParametersIVF; everything else
+// uses base.
+std::unique_ptr<::faiss::SearchParameters> make_search_params(
+        const ::faiss::IndexBinary* index);
+
+// Set a named runtime knob. Walks into PreTransform / Refine wrappers.
+// Returns true if recognized and applied by some layer, false otherwise.
+// double is used as the common value type (matches faiss::ParameterSpace).
+bool try_set_search_param(
+        ::faiss::SearchParameters* params,
+        const std::string& name,
+        double val);
+
+// Returns the whitelist of search-time parameter names recognized by
+// try_set_search_param for this index. Includes wrapper-level knobs
+// (e.g. k_factor for IndexRefine) plus the inner family's knobs.
+// Callers should use this to pre-validate user-supplied params.
+std::set<std::string> supported_search_params(const ::faiss::Index* index);
+
+std::set<std::string> supported_search_params(
+        const ::faiss::IndexBinary* index);
+
+// ---------- Build-param dispatch ----------
+
+// Returns whether faiss::ParameterSpace::set_index_parameter would recognize
+// the given name (i.e. it appears in faiss's own hardcoded if-chain, including
+// the "quantizer_" prefix for recursing into coarse quantizers). Useful to pre-
+// validate user-supplied build-time knobs before forwarding to ParameterSpace.
+bool is_supported_build_param(const std::string& name);
+
+// The fixed set of base names recognized by
+// ParameterSpace::set_index_parameter. Note: "quantizer_<name>" where <name> is
+// one of these is also supported via ParameterSpace's recursion — use
+// is_supported_build_param for that case.
+std::set<std::string> supported_build_param_names();
+
+} // namespace faiss::cppcontrib::knowhere