diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 24ef5ccd6..e05419c76 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -58,6 +58,10 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_BFLOAT16}, // {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_INT8}, + {IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_BFLOAT16}, + // gpu index {IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT}, {IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT}, @@ -108,6 +112,7 @@ static std::set legal_support_mmap_knowhere_index = { IndexEnum::INDEX_FAISS_SCANN, IndexEnum::INDEX_FAISS_IVFSQ8, IndexEnum::INDEX_FAISS_IVFSQ_CC, + IndexEnum::INDEX_FAISS_IVFRABITQ, // hnsw IndexEnum::INDEX_HNSW, diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 8418c22f5..e836ac1f6 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -40,6 +40,7 @@ #include "index/hnsw/impl/IndexConditionalWrapper.h" #include "index/hnsw/impl/IndexHNSWWrapper.h" #include "index/hnsw/impl/IndexWrapperCosine.h" +#include "index/refine/refine_utils.h" #include "io/memory_io.h" #include "knowhere/bitsetview_idselector.h" #include "knowhere/comp/index_param.h" @@ -2034,205 +2035,6 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public HNSWI } }; -namespace { - -// a supporting function -expected -get_sq_quantizer_type(const std::string& sq_type) { - std::map sq_types = { - {"sq6", faiss::ScalarQuantizer::QT_6bit}, - {"sq8", faiss::ScalarQuantizer::QT_8bit}, - {"fp16", faiss::ScalarQuantizer::QT_fp16}, - {"bf16", faiss::ScalarQuantizer::QT_bf16}, - {"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}}; - - // todo: tolower - auto sq_type_tolower = str_to_lower(sq_type); - auto itr = sq_types.find(sq_type_tolower); - if (itr == sq_types.cend()) { - return expected::Err( - Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type_tolower)); - } - - return itr->second; -} - -/* -// checks whether an index contains a refiner, suitable for a given data format -std::optional whether_refine_is_datatype( - const faiss::Index* index, - const DataFormatEnum data_format -) { - if (index == nullptr) { - return {}; - } - - const faiss::IndexRefine* const index_refine = dynamic_cast(index); - if (index_refine == nullptr) { - return false; - } - - switch(data_format) { - case DataFormatEnum::fp32: - return (dynamic_cast(index_refine->refine_index) != nullptr); - case DataFormatEnum::fp16: - { - const auto* const index_sq = dynamic_cast(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype == -faiss::ScalarQuantizer::QT_fp16); - } - case DataFormatEnum::bf16: - { - const auto* const index_sq = dynamic_cast(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype == -faiss::ScalarQuantizer::QT_bf16); - } - default: - return {}; - } -} -*/ - -expected -is_flat_refine(const std::optional& refine_type) { - // grab a type of a refine index - if (!refine_type.has_value()) { - return true; - }; - - // todo: tolower - std::string refine_type_tolower = str_to_lower(refine_type.value()); - if (refine_type_tolower == "fp32" || refine_type_tolower == "flat") { - return true; - }; - - // parse - auto refine_sq_type = get_sq_quantizer_type(refine_type_tolower); - if (!refine_sq_type.has_value()) { - LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); - return expected::Err(Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); - } - - return false; -} - -bool -has_lossless_quant(const expected& quant_type, DataFormatEnum dataFormat) { - if (!quant_type.has_value()) { - return false; - } - - auto quant = quant_type.value(); - switch (dataFormat) { - case DataFormatEnum::fp32: - return false; - case DataFormatEnum::fp16: - return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16; - case DataFormatEnum::bf16: - return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16; - case DataFormatEnum::int8: - return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed; - default: - return false; - } -} - -bool -has_lossless_refine_index(const FaissHnswConfig& hnsw_cfg, DataFormatEnum dataFormat) { - bool has_refine = hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value(); - if (has_refine) { - expected flat_refine = is_flat_refine(hnsw_cfg.refine_type); - if (flat_refine.has_value() && flat_refine.value()) { - return true; - } - - auto sq_refine_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value()); - return has_lossless_quant(sq_refine_type, dataFormat); - } - return false; -} - -// pick a refine index -expected> -pick_refine_index(const DataFormatEnum data_format, const std::optional& refine_type, - std::unique_ptr&& hnsw_index) { - // yes - - // grab a type of a refine index - expected is_fp32_flat = is_flat_refine(refine_type); - if (!is_fp32_flat.has_value()) { - return expected>::Err(Status::invalid_args, ""); - } - - const bool is_fp32_flat_v = is_fp32_flat.value(); - - // check input data_format - if (data_format == DataFormatEnum::fp16) { - // make sure that we're using fp16 refine - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); - if (!(refine_sq_type.has_value() && - (refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && !is_fp32_flat_v))) { - LOG_KNOWHERE_ERROR_ << "fp16 input data does not accept bf16 or fp32 as a refine index."; - return expected>::Err( - Status::invalid_args, "fp16 input data does not accept bf16 or fp32 as a refine index."); - } - } - - if (data_format == DataFormatEnum::bf16) { - // make sure that we're using bf16 refine - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); - if (!(refine_sq_type.has_value() && - (refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && !is_fp32_flat_v))) { - LOG_KNOWHERE_ERROR_ << "bf16 input data does not accept fp16 or fp32 as a refine index."; - return expected>::Err( - Status::invalid_args, "bf16 input data does not accept fp16 or fp32 as a refine index."); - } - } - - // build - std::unique_ptr local_hnsw_index = std::move(hnsw_index); - - // either build flat or sq - if (is_fp32_flat_v) { - // build IndexFlat as a refine - auto refine_index = std::make_unique(local_hnsw_index.get()); - - // let refine_index to own everything - refine_index->own_fields = true; - local_hnsw_index.release(); - - // reassign - return refine_index; - } else { - // being IndexScalarQuantizer as a refine - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); - - // a redundant check - if (!refine_sq_type.has_value()) { - LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); - return expected>::Err( - Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); - } - - // create an sq - auto sq_refine = std::make_unique( - local_hnsw_index->storage->d, refine_sq_type.value(), local_hnsw_index->storage->metric_type); - - auto refine_index = std::make_unique(local_hnsw_index.get(), sq_refine.get()); - - // let refine_index to own everything - refine_index->own_refine_index = true; - refine_index->own_fields = true; - local_hnsw_index.release(); - sq_refine.release(); - - // reassign - return refine_index; - } -} - -} // namespace - // class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { public: @@ -2300,7 +2102,10 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + const auto hnsw_d = hnsw_index->storage->d; + const auto hnsw_metric_type = hnsw_index->storage->metric_type; + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), + hnsw_d, hnsw_metric_type); if (!final_index_cnd.has_value()) { return Status::invalid_args; } @@ -2368,7 +2173,7 @@ class BaseFaissRegularIndexHNSWSQNodeTemplate : public BaseFaissRegularIndexHNSW return true; } - return has_lossless_refine_index(hnsw_sq_cfg, datatype_v); + return has_lossless_refine_index(hnsw_sq_cfg.refine, hnsw_sq_cfg.refine_type, datatype_v); } }; @@ -2449,7 +2254,10 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { std::unique_ptr final_index; if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + const auto hnsw_d = hnsw_index->storage->d; + const auto hnsw_metric_type = hnsw_index->storage->metric_type; + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), + hnsw_d, hnsw_metric_type); if (!final_index_cnd.has_value()) { return Status::invalid_args; } @@ -2640,7 +2448,7 @@ class BaseFaissRegularIndexHNSWPQNodeTemplate : public BaseFaissRegularIndexHNSW static bool StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) { auto hnsw_cfg = static_cast(config); - return has_lossless_refine_index(hnsw_cfg, datatype_v); + return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v); } }; @@ -2728,7 +2536,10 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { std::unique_ptr final_index; if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + const auto hnsw_d = hnsw_index->storage->d; + const auto hnsw_metric_type = hnsw_index->storage->metric_type; + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), + hnsw_d, hnsw_metric_type); if (!final_index_cnd.has_value()) { return Status::invalid_args; } @@ -2920,7 +2731,7 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS static bool StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) { auto hnsw_cfg = static_cast(config); - return has_lossless_refine_index(hnsw_cfg, datatype_v); + return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v); } }; diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index baf01df88..351846062 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -19,7 +19,6 @@ #include "faiss/IndexIVFPQFastScan.h" #include "faiss/IndexIVFRaBitQ.h" #include "faiss/IndexIVFScalarQuantizerCC.h" -#include "faiss/IndexPreTransform.h" #include "faiss/IndexScaNN.h" #include "faiss/IndexScalarQuantizer.h" #include "faiss/VectorTransform.h" @@ -670,9 +669,15 @@ IvfIndexNode::TrainInternal(const DataSetPtr dataset, std:: if constexpr (std::is_same::value) { const IvfRaBitQConfig& ivf_rabitq_cfg = static_cast(*cfg); auto nlist = MatchNlist(rows, ivf_rabitq_cfg.nlist.value()); - auto qb = ivf_rabitq_cfg.rbq_bits_query.value(); - index = std::make_unique(dim, nlist, qb, metric.value()); + DataFormatEnum data_format = DataType2EnumHelper::value; + + auto result = IndexIVFRaBitQWrapper::create(dim, nlist, ivf_rabitq_cfg, data_format, metric.value()); + if (!result.has_value()) { + return result.error(); + } + + index = std::move(result.value()); index->train(rows, (const float*)data); } index_ = std::move(index); @@ -835,13 +840,36 @@ IvfIndexNode::Search(const DataSetPtr dataset, std::unique_ const IvfRaBitQConfig& ivf_rabitq_cfg = static_cast(*cfg); + // use refine? + bool use_refine = false; + + const bool whether_to_enable_refine = ivf_rabitq_cfg.refine_k.has_value(); + if (const auto wrapper_index = dynamic_cast(index_.get()); + wrapper_index != nullptr) { + const faiss::IndexRefine* refine_index = wrapper_index->get_refine_index(); + use_refine = (refine_index != nullptr); + } + faiss::IVFRaBitQSearchParameters ivf_search_params; ivf_search_params.nprobe = nprobe; ivf_search_params.max_codes = 0; ivf_search_params.sel = id_selector; ivf_search_params.qb = ivf_rabitq_cfg.rbq_bits_query.value_or(0); - index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset, &ivf_search_params); + if (use_refine && whether_to_enable_refine) { + // yes, use refine + faiss::IndexRefineSearchParameters refine_search_params; + refine_search_params.sel = id_selector; + refine_search_params.k_factor = ivf_rabitq_cfg.refine_k.value_or(1); + refine_search_params.base_index_params = &ivf_search_params; + + index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset, + &refine_search_params); + } else { + // do not use refine + index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset, + &ivf_search_params); + } } else { auto cur_query = (const float*)data + index * dim; if (is_cosine) { @@ -964,7 +992,27 @@ IvfIndexNode::RangeSearch(const DataSetPtr dataset, std::un ivf_search_params.sel = id_selector; ivf_search_params.qb = ivf_rabitq_cfg.rbq_bits_query.value_or(0); - index_->range_search(1, cur_query, radius, &res, &ivf_search_params); + // use refine? + bool use_refine = false; + + const bool whether_to_enable_refine = ivf_rabitq_cfg.refine_k.has_value(); + if (const auto wrapper_index = dynamic_cast(index_.get()); + wrapper_index != nullptr) { + const faiss::IndexRefine* refine_index = wrapper_index->get_refine_index(); + use_refine = (refine_index != nullptr); + } + + if (use_refine && whether_to_enable_refine) { + // yes, use refine + faiss::IndexRefineSearchParameters refine_search_params; + refine_search_params.sel = id_selector; + refine_search_params.k_factor = ivf_rabitq_cfg.refine_k.value_or(1); + refine_search_params.base_index_params = &ivf_search_params; + + index_->range_search(1, cur_query, radius, &res, &refine_search_params); + } else { + index_->range_search(1, cur_query, radius, &res, &ivf_search_params); + } } else { auto cur_query = (const float*)xq + index * dim; if (is_cosine) { diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index 93ef77642..436dcffe9 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -12,7 +12,13 @@ #ifndef IVF_CONFIG_H #define IVF_CONFIG_H +#include +#include +#include +#include + #include "knowhere/config.h" +#include "knowhere/tolower.h" #include "simd/hook.h" namespace knowhere { @@ -223,6 +229,13 @@ class IvfSqCcConfig : public IvfFlatCcConfig { class IvfRaBitQConfig : public IvfConfig { public: + // whether an index is built with a refine support + CFG_BOOL refine; + // undefined value leads to a search without a refine + CFG_FLOAT refine_k; + // type of refine + CFG_STRING refine_type; + // the value `0` means that the query won't be quantized and will // be processed as is. CFG_INT rbq_bits_query; @@ -233,6 +246,59 @@ class IvfRaBitQConfig : public IvfConfig { .set_range(0, 8) .for_search() .for_range_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(refine) + .description("whether the refine is used during the train") + .set_default(false) + .for_train() + .for_static(); + KNOWHERE_CONFIG_DECLARE_FIELD(refine_k) + .description("refine k") + .set_default(1) + .set_range(1, std::numeric_limits::max()) + .for_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(refine_type) + .description("the type of a refine index") + .allow_empty_without_default() + .for_train() + .for_static(); + } + + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + // check the base class + const auto base_status = IvfConfig::CheckAndAdjust(param_type, err_msg); + if (base_status != Status::success) { + return base_status; + } + + // check our parameters + if (param_type == PARAM_TYPE::TRAIN) { + // check refine + if (refine_type.has_value()) { + if (!WhetherAcceptableRefineType(refine_type.value())) { + std::string msg = + "invalid refine type : " + refine_type.value() + ", optional types are [sq6, sq8, fp16, bf16]"; + return HandleError(err_msg, msg, Status::invalid_args); + } + } + } + return Status::success; + } + + protected: + bool + WhetherAcceptableRefineType(const std::string& refine_type) { + // 'flat' is identical to 'fp32' + std::vector allowed_list = {"sq6", "sq8", "fp16", "bf16", "fp32", "flat"}; + std::string refine_type_tolower = str_to_lower(refine_type); + + for (const auto& allowed : allowed_list) { + if (refine_type_tolower == allowed) { + return true; + } + } + + return false; } }; diff --git a/src/index/ivf/ivfrbq_wrapper.cc b/src/index/ivf/ivfrbq_wrapper.cc index a9f355cc9..d9254d764 100644 --- a/src/index/ivf/ivfrbq_wrapper.cc +++ b/src/index/ivf/ivfrbq_wrapper.cc @@ -11,45 +11,71 @@ #include "index/ivf/ivfrbq_wrapper.h" -#include - #include +#include "faiss/IndexCosine.h" #include "faiss/IndexFlat.h" #include "faiss/IndexPreTransform.h" +#include "faiss/cppcontrib/knowhere/impl/CountSizeIOWriter.h" #include "faiss/index_io.h" +#include "index/refine/refine_utils.h" namespace knowhere { -IndexIVFRaBitQWrapper::IndexIVFRaBitQWrapper(const faiss::idx_t d, const size_t nlist, const uint8_t qb, - faiss::MetricType metric) - : faiss::Index(d, metric) { +expected> +IndexIVFRaBitQWrapper::create(const faiss::idx_t d, const size_t nlist, const IvfRaBitQConfig& ivf_rabitq_cfg, + const DataFormatEnum raw_data_format, const faiss::MetricType metric) { + // the index factory string is either `RR(dim),IVFx,RaBitQ,Refine(y)`, + // or `RR(dim),IVFx,RaBitQ`, depends on the refine parameters + + // create IndexIVFRaBitQ + auto qb = ivf_rabitq_cfg.rbq_bits_query.value(); + auto idx_flat = std::make_unique(d, metric, false); auto idx_ivfrbq = std::make_unique(idx_flat.release(), d, nlist, metric); idx_ivfrbq->own_fields = true; idx_ivfrbq->qb = qb; + // wrap it in an IndexPreTransform auto rr = std::make_unique(d, d); auto idx_rr = std::make_unique(rr.release(), idx_ivfrbq.release()); idx_rr->own_fields = true; - index = std::move(idx_rr); + // create a refiner index, if needed + std::unique_ptr idx_final; + if (ivf_rabitq_cfg.refine.value_or(false) && ivf_rabitq_cfg.refine_type.has_value()) { + // refine is needed + const auto base_d = idx_rr->d; + const auto base_metric_type = idx_rr->metric_type; + auto final_index_cnd = + pick_refine_index(raw_data_format, ivf_rabitq_cfg.refine_type, std::move(idx_rr), base_d, base_metric_type); + if (!final_index_cnd.has_value()) { + return expected>::Err(Status::invalid_args, + "Invalid refine parameters"); + } + + idx_final = std::move(final_index_cnd.value()); + } else { + // refine is not needed + idx_final = std::move(idx_rr); + } - this->is_trained = index->is_trained; - this->is_cosine = index->is_cosine; + auto result = std::make_unique(std::move(idx_final)); + return result; } IndexIVFRaBitQWrapper::IndexIVFRaBitQWrapper(std::unique_ptr&& index_in) : Index{index_in->d, index_in->metric_type}, index{std::move(index_in)} { ntotal = index->ntotal; is_trained = index->is_trained; + is_cosine = index->is_cosine; verbose = index->verbose; metric_arg = index->metric_arg; } std::unique_ptr IndexIVFRaBitQWrapper::from_deserialized(std::unique_ptr&& index_in) { - auto index = std::unique_ptr(new IndexIVFRaBitQWrapper(std::move(index_in))); + auto index = std::make_unique(std::move(index_in)); // check a provided index type auto index_rabitq = index->get_ivfrabitq_index(); @@ -106,7 +132,12 @@ IndexIVFRaBitQWrapper::get_distance_computer() const { faiss::IndexIVFRaBitQ* IndexIVFRaBitQWrapper::get_ivfrabitq_index() { - faiss::IndexPreTransform* index_pt = dynamic_cast(index.get()); + // try refine + faiss::IndexRefine* index_refine = dynamic_cast(index.get()); + faiss::Index* index_for_pt = (index_refine != nullptr) ? index_refine->base_index : index.get(); + + // pre-transform + faiss::IndexPreTransform* index_pt = dynamic_cast(index_for_pt); if (index_pt == nullptr) { return nullptr; } @@ -116,7 +147,12 @@ IndexIVFRaBitQWrapper::get_ivfrabitq_index() { const faiss::IndexIVFRaBitQ* IndexIVFRaBitQWrapper::get_ivfrabitq_index() const { - const faiss::IndexPreTransform* index_pt = dynamic_cast(index.get()); + // try refine + const faiss::IndexRefine* index_refine = dynamic_cast(index.get()); + const faiss::Index* index_for_pt = (index_refine != nullptr) ? index_refine->base_index : index.get(); + + // pre-transform + const faiss::IndexPreTransform* index_pt = dynamic_cast(index_for_pt); if (index_pt == nullptr) { return nullptr; } @@ -124,6 +160,16 @@ IndexIVFRaBitQWrapper::get_ivfrabitq_index() const { return dynamic_cast(index_pt->index); } +faiss::IndexRefine* +IndexIVFRaBitQWrapper::get_refine_index() { + return dynamic_cast(index.get()); +} + +const faiss::IndexRefine* +IndexIVFRaBitQWrapper::get_refine_index() const { + return dynamic_cast(index.get()); +} + size_t IndexIVFRaBitQWrapper::size() const { if (index == nullptr) { @@ -141,7 +187,11 @@ IndexIVFRaBitQWrapper::size() const { std::unique_ptr IndexIVFRaBitQWrapper::getIteratorWorkspace(const float* query_data, const faiss::IVFSearchParameters* ivfsearchParams) const { - const faiss::IndexPreTransform* index_pt = dynamic_cast(index.get()); + // try refine + const faiss::IndexRefine* index_refine = dynamic_cast(index.get()); + faiss::Index* index_for_pt = (index_refine != nullptr) ? index_refine->base_index : index.get(); + + const faiss::IndexPreTransform* index_pt = dynamic_cast(index_for_pt); if (index_pt == nullptr) { return nullptr; } @@ -153,8 +203,24 @@ IndexIVFRaBitQWrapper::getIteratorWorkspace(const float* query_data, // ok, transform the query std::unique_ptr transformed_query(index_pt->apply_chain(1, query_data)); - // create a workspace + // create a workspace. This will make a clone of the transformed_query. auto workspace = index_rbq->getIteratorWorkspace(transformed_query.get(), ivfsearchParams); + + // check if refine exists + if (index_refine != nullptr) { + // create a distance + // index_rbq == index_refine->base_index + + // a regular use case + workspace->dis_refine = + std::unique_ptr(index_refine->refine_index->get_distance_computer()); + // this points to a previously saved clone + workspace->dis_refine->set_query(workspace->query_data.data()); + } else { + // don't use refine + workspace->dis_refine = nullptr; + } + // done return workspace; } diff --git a/src/index/ivf/ivfrbq_wrapper.h b/src/index/ivf/ivfrbq_wrapper.h index 666489fc8..62c60aace 100644 --- a/src/index/ivf/ivfrbq_wrapper.h +++ b/src/index/ivf/ivfrbq_wrapper.h @@ -18,20 +18,30 @@ #include "faiss/Index.h" #include "faiss/IndexIVF.h" #include "faiss/IndexIVFRaBitQ.h" +#include "faiss/IndexRefine.h" +#include "index/ivf/ivf_config.h" +#include "knowhere/expected.h" namespace knowhere { // This is wrapper is needed, bcz we use faiss::IndexPreTransform -// for wrapping faiss::IndexIVFRaBitQ. The problem is that -// IndexPreTransform is a generic class, suitable for any other -// use case as well, so this is wrong to reference IndexPreTransform -// in the ivf.cc file. +// for wrapping faiss::IndexIVFRaBitQ, optionally combined with +// faiss::IndexRefine. +// The problem is that IndexPreTransform is a generic class, suitable +// for any other use case as well, so this is wrong to reference +// IndexPreTransform in the ivf.cc file. struct IndexIVFRaBitQWrapper : faiss::Index { + // this is one of two: + // * faiss::IndexPreTransform + faiss::IndexIVFRaBitQ + // * faiss::IndexPreTransform + faiss::IndexRefine + faiss::IndexIVFRaBitQ std::unique_ptr index; - // this form is for a regular index constructoin - IndexIVFRaBitQWrapper(const faiss::idx_t d, const size_t nlist, const uint8_t qb, - faiss::MetricType metric = faiss::METRIC_L2); + IndexIVFRaBitQWrapper(std::unique_ptr&& index_in); + + static expected> + create(const faiss::idx_t d, const size_t nlist, const IvfRaBitQConfig& ivf_rabitq_cfg, + // this is the data format of the raw data (if the refine is used) + const DataFormatEnum raw_data_format, const faiss::MetricType metric = faiss::METRIC_L2); // this is for the deserialization. // returns nullptr if the provided index type is not the one @@ -62,12 +72,19 @@ struct IndexIVFRaBitQWrapper : faiss::Index { faiss::DistanceComputer* get_distance_computer() const override; - // point to IndexIVFRaBitQ or return nullptr + // point to IndexIVFRaBitQ or return nullptr. + // this may also point to an index, owned by IndexRefine faiss::IndexIVFRaBitQ* get_ivfrabitq_index(); const faiss::IndexIVFRaBitQ* get_ivfrabitq_index() const; + // point to IndexRefine or return nullptr. + faiss::IndexRefine* + get_refine_index(); + const faiss::IndexRefine* + get_refine_index() const; + // return the size of the index size_t size() const; @@ -77,9 +94,6 @@ struct IndexIVFRaBitQWrapper : faiss::Index { void getIteratorNextBatch(faiss::IVFIteratorWorkspace* workspace, size_t current_backup_count) const; - - private: - IndexIVFRaBitQWrapper(std::unique_ptr&& index_in); }; } // namespace knowhere diff --git a/src/index/refine/refine_utils.cc b/src/index/refine/refine_utils.cc new file mode 100644 index 000000000..aed64b503 --- /dev/null +++ b/src/index/refine/refine_utils.cc @@ -0,0 +1,177 @@ +#include "index/refine/refine_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "faiss/IndexRefine.h" +#include "faiss/IndexScalarQuantizer.h" +#include "fmt/format.h" +#include "knowhere/log.h" +#include "knowhere/tolower.h" + +namespace knowhere { + +// a supporting function +expected +get_sq_quantizer_type(const std::string& sq_type) { + std::map sq_types = { + {"sq6", faiss::ScalarQuantizer::QT_6bit}, + {"sq8", faiss::ScalarQuantizer::QT_8bit}, + {"fp16", faiss::ScalarQuantizer::QT_fp16}, + {"bf16", faiss::ScalarQuantizer::QT_bf16}, + {"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}}; + + // todo: tolower + auto sq_type_tolower = str_to_lower(sq_type); + auto itr = sq_types.find(sq_type_tolower); + if (itr == sq_types.cend()) { + return expected::Err( + Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type_tolower)); + } + + return itr->second; +} + +expected +is_flat_refine(const std::optional& refine_type) { + // grab a type of a refine index + if (!refine_type.has_value()) { + return true; + }; + + // todo: tolower + std::string refine_type_tolower = str_to_lower(refine_type.value()); + if (refine_type_tolower == "fp32" || refine_type_tolower == "flat") { + return true; + }; + + // parse + auto refine_sq_type = get_sq_quantizer_type(refine_type_tolower); + if (!refine_sq_type.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); + return expected::Err(Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); + } + + return false; +} + +bool +has_lossless_quant(const expected& quant_type, DataFormatEnum dataFormat) { + if (!quant_type.has_value()) { + return false; + } + + auto quant = quant_type.value(); + switch (dataFormat) { + case DataFormatEnum::fp32: + return false; + case DataFormatEnum::fp16: + return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16; + case DataFormatEnum::bf16: + return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16; + case DataFormatEnum::int8: + return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed; + default: + return false; + } +} + +bool +has_lossless_refine_index(const std::optional& refine, const std::optional& refine_type, + DataFormatEnum dataFormat) { + bool has_refine = refine.value_or(false) && refine_type.has_value(); + if (has_refine) { + expected flat_refine = is_flat_refine(refine_type); + if (flat_refine.has_value() && flat_refine.value()) { + return true; + } + + auto sq_refine_type = get_sq_quantizer_type(refine_type.value()); + return has_lossless_quant(sq_refine_type, dataFormat); + } + return false; +} + +// pick a refine index +expected> +pick_refine_index(const DataFormatEnum data_format, const std::optional& refine_type, + std::unique_ptr&& base_index, const size_t base_d, + const faiss::MetricType base_metric_type) { + // grab a type of a refine index + expected is_fp32_flat = is_flat_refine(refine_type); + if (!is_fp32_flat.has_value()) { + return expected>::Err(Status::invalid_args, ""); + } + + const bool is_fp32_flat_v = is_fp32_flat.value(); + + // check input data_format + if (data_format == DataFormatEnum::fp16) { + // make sure that we're using fp16 refine + auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); + if (!(refine_sq_type.has_value() && + (refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && !is_fp32_flat_v))) { + LOG_KNOWHERE_ERROR_ << "fp16 input data does not accept bf16 or fp32 as a refine index."; + return expected>::Err( + Status::invalid_args, "fp16 input data does not accept bf16 or fp32 as a refine index."); + } + } + + if (data_format == DataFormatEnum::bf16) { + // make sure that we're using bf16 refine + auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); + if (!(refine_sq_type.has_value() && + (refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && !is_fp32_flat_v))) { + LOG_KNOWHERE_ERROR_ << "bf16 input data does not accept fp16 or fp32 as a refine index."; + return expected>::Err( + Status::invalid_args, "bf16 input data does not accept fp16 or fp32 as a refine index."); + } + } + + // build + std::unique_ptr local_index = std::move(base_index); + + // either build flat or sq + if (is_fp32_flat_v) { + // build IndexFlat as a refine + auto refine_index = std::make_unique(local_index.get()); + + // let refine_index to own everything + refine_index->own_fields = true; + local_index.release(); + + // reassign + return refine_index; + } else { + // being IndexScalarQuantizer as a refine + auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); + + // a redundant check + if (!refine_sq_type.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); + return expected>::Err( + Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); + } + + // create an sq + auto sq_refine = + std::make_unique(base_d, refine_sq_type.value(), base_metric_type); + + auto refine_index = std::make_unique(local_index.get(), sq_refine.get()); + + // let refine_index to own everything + refine_index->own_refine_index = true; + refine_index->own_fields = true; + local_index.release(); + sq_refine.release(); + + // reassign + return refine_index; + } +} + +} // namespace knowhere diff --git a/src/index/refine/refine_utils.h b/src/index/refine/refine_utils.h new file mode 100644 index 000000000..21d9b768c --- /dev/null +++ b/src/index/refine/refine_utils.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "faiss/impl/ScalarQuantizer.h" +#include "knowhere/expected.h" +#include "knowhere/operands.h" + +namespace knowhere { + +expected +get_sq_quantizer_type(const std::string& sq_type); + +expected +is_flat_refine(const std::optional& refine_type); + +bool +has_lossless_quant(const expected& quant_type, DataFormatEnum dataFormat); + +bool +has_lossless_refine_index(const std::optional& refine, const std::optional& refine_type, + DataFormatEnum dataFormat); + +expected> +pick_refine_index(const DataFormatEnum data_format, const std::optional& refine_type, + std::unique_ptr&& base_index, + // These two could be borrowed from base_index. But it seems that + // for HNSW these things are borrowed from base_index.storage. + // So, let's provide these externally + const size_t base_d, const faiss::MetricType base_metric_type); + +} // namespace knowhere diff --git a/tests/ut/test_get_vector.cc b/tests/ut/test_get_vector.cc index 710109ce9..f01e8d634 100644 --- a/tests/ut/test_get_vector.cc +++ b/tests/ut/test_get_vector.cc @@ -172,6 +172,13 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { auto ivfrabitq_gen = ivfflat_gen; + auto ivfrabitq_refine_flat_gen = [ivfrabitq_gen]() { + knowhere::Json json = ivfrabitq_gen(); + json["refine"] = true; + json["refine_type"] = "FLAT"; + return json; + }; + SECTION("Test float index") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>( @@ -183,7 +190,8 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), - make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen)})); + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_refine_flat_gen)})); auto idx_expected = knowhere::IndexFactory::Instance().Create(name, version); if (name == knowhere::IndexEnum::INDEX_FAISS_SCANN) { diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index f718da298..d53a31c2a 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -204,6 +204,13 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto ivfrabitq_gen = ivf_base_gen; + auto ivfrabitq_refine_flat_gen = [ivfrabitq_gen] { + knowhere::Json json = ivfrabitq_gen(); + json["refine"] = true; + json["refine_type"] = "FLAT"; + return json; + }; + auto rand = GENERATE(1, 2); const auto train_ds = GenDataSet(nb, dim, rand); @@ -228,7 +235,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), - make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen)})); + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_refine_flat_gen)})); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); @@ -316,7 +324,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), - make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen)})); + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_refine_flat_gen)})); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); @@ -364,7 +373,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), - make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen)})); + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_refine_flat_gen)})); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index bef0a546a..67ed2d359 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -148,6 +148,13 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { return json; }; + auto ivfrabitq_refine_flat_gen = [ivfrabitq_gen]() { + knowhere::Json json = ivfrabitq_gen(); + json["refine"] = true; + json["refine_type"] = "FLAT"; + return json; + }; + const auto train_ds = GenDataSet(nb, dim); const auto query_ds = GenDataSet(nq, dim); @@ -173,7 +180,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen), - make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen)})); + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_refine_flat_gen)})); knowhere::BinarySet bs; // build process { @@ -265,7 +273,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen), - make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen)})); + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_refine_flat_gen)})); auto idx_expected = knowhere::IndexFactory::Instance().Create(name, version); if (name == knowhere::IndexEnum::INDEX_FAISS_SCANN) { // need to check cpu model for scann @@ -478,7 +487,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen), - make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen)})); + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFRABITQ, ivfrabitq_refine_flat_gen)})); auto idx_expected = knowhere::IndexFactory::Instance().Create(name, version); if (name == knowhere::IndexEnum::INDEX_FAISS_SCANN) {