Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/knowhere/index/index_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_BFLOAT16},
// {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_INT8},

{IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_BFLOAT16},

// gpu index
{IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT},
Expand Down Expand Up @@ -108,6 +112,7 @@ static std::set<std::string> legal_support_mmap_knowhere_index = {
IndexEnum::INDEX_FAISS_SCANN,
IndexEnum::INDEX_FAISS_IVFSQ8,
IndexEnum::INDEX_FAISS_IVFSQ_CC,
IndexEnum::INDEX_FAISS_IVFRABITQ,

// hnsw
IndexEnum::INDEX_HNSW,
Expand Down
221 changes: 16 additions & 205 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "index/hnsw/impl/IndexConditionalWrapper.h"
#include "index/hnsw/impl/IndexHNSWWrapper.h"
#include "index/hnsw/impl/IndexWrapperCosine.h"
#include "index/refine/refine_utils.h"
#include "io/memory_io.h"
#include "knowhere/bitsetview_idselector.h"
#include "knowhere/comp/index_param.h"
Expand Down Expand Up @@ -2034,205 +2035,6 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public HNSWI
}
};

namespace {

// a supporting function
expected<faiss::ScalarQuantizer::QuantizerType>
get_sq_quantizer_type(const std::string& sq_type) {
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {
{"sq6", faiss::ScalarQuantizer::QT_6bit},
{"sq8", faiss::ScalarQuantizer::QT_8bit},
{"fp16", faiss::ScalarQuantizer::QT_fp16},
{"bf16", faiss::ScalarQuantizer::QT_bf16},
{"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}};

// todo: tolower
auto sq_type_tolower = str_to_lower(sq_type);
auto itr = sq_types.find(sq_type_tolower);
if (itr == sq_types.cend()) {
return expected<faiss::ScalarQuantizer::QuantizerType>::Err(
Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type_tolower));
}

return itr->second;
}

/*
// checks whether an index contains a refiner, suitable for a given data format
std::optional<bool> whether_refine_is_datatype(
const faiss::Index* index,
const DataFormatEnum data_format
) {
if (index == nullptr) {
return {};
}

const faiss::IndexRefine* const index_refine = dynamic_cast<const faiss::IndexRefine*>(index);
if (index_refine == nullptr) {
return false;
}

switch(data_format) {
case DataFormatEnum::fp32:
return (dynamic_cast<const faiss::IndexFlat*>(index_refine->refine_index) != nullptr);
case DataFormatEnum::fp16:
{
const auto* const index_sq = dynamic_cast<const
faiss::IndexScalarQuantizer*>(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype ==
faiss::ScalarQuantizer::QT_fp16);
}
case DataFormatEnum::bf16:
{
const auto* const index_sq = dynamic_cast<const
faiss::IndexScalarQuantizer*>(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype ==
faiss::ScalarQuantizer::QT_bf16);
}
default:
return {};
}
}
*/

expected<bool>
is_flat_refine(const std::optional<std::string>& refine_type) {
// grab a type of a refine index
if (!refine_type.has_value()) {
return true;
};

// todo: tolower
std::string refine_type_tolower = str_to_lower(refine_type.value());
if (refine_type_tolower == "fp32" || refine_type_tolower == "flat") {
return true;
};

// parse
auto refine_sq_type = get_sq_quantizer_type(refine_type_tolower);
if (!refine_sq_type.has_value()) {
LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value();
return expected<bool>::Err(Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value()));
}

return false;
}

bool
has_lossless_quant(const expected<faiss::ScalarQuantizer::QuantizerType>& quant_type, DataFormatEnum dataFormat) {
if (!quant_type.has_value()) {
return false;
}

auto quant = quant_type.value();
switch (dataFormat) {
case DataFormatEnum::fp32:
return false;
case DataFormatEnum::fp16:
return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16;
case DataFormatEnum::bf16:
return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16;
case DataFormatEnum::int8:
return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed;
default:
return false;
}
}

bool
has_lossless_refine_index(const FaissHnswConfig& hnsw_cfg, DataFormatEnum dataFormat) {
bool has_refine = hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value();
if (has_refine) {
expected<bool> flat_refine = is_flat_refine(hnsw_cfg.refine_type);
if (flat_refine.has_value() && flat_refine.value()) {
return true;
}

auto sq_refine_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value());
return has_lossless_quant(sq_refine_type, dataFormat);
}
return false;
}

// pick a refine index
expected<std::unique_ptr<faiss::Index>>
pick_refine_index(const DataFormatEnum data_format, const std::optional<std::string>& refine_type,
std::unique_ptr<faiss::IndexHNSW>&& hnsw_index) {
// yes

// grab a type of a refine index
expected<bool> is_fp32_flat = is_flat_refine(refine_type);
if (!is_fp32_flat.has_value()) {
return expected<std::unique_ptr<faiss::Index>>::Err(Status::invalid_args, "");
}

const bool is_fp32_flat_v = is_fp32_flat.value();

// check input data_format
if (data_format == DataFormatEnum::fp16) {
// make sure that we're using fp16 refine
auto refine_sq_type = get_sq_quantizer_type(refine_type.value());
if (!(refine_sq_type.has_value() &&
(refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && !is_fp32_flat_v))) {
LOG_KNOWHERE_ERROR_ << "fp16 input data does not accept bf16 or fp32 as a refine index.";
return expected<std::unique_ptr<faiss::Index>>::Err(
Status::invalid_args, "fp16 input data does not accept bf16 or fp32 as a refine index.");
}
}

if (data_format == DataFormatEnum::bf16) {
// make sure that we're using bf16 refine
auto refine_sq_type = get_sq_quantizer_type(refine_type.value());
if (!(refine_sq_type.has_value() &&
(refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && !is_fp32_flat_v))) {
LOG_KNOWHERE_ERROR_ << "bf16 input data does not accept fp16 or fp32 as a refine index.";
return expected<std::unique_ptr<faiss::Index>>::Err(
Status::invalid_args, "bf16 input data does not accept fp16 or fp32 as a refine index.");
}
}

// build
std::unique_ptr<faiss::IndexHNSW> local_hnsw_index = std::move(hnsw_index);

// either build flat or sq
if (is_fp32_flat_v) {
// build IndexFlat as a refine
auto refine_index = std::make_unique<faiss::IndexRefineFlat>(local_hnsw_index.get());

// let refine_index to own everything
refine_index->own_fields = true;
local_hnsw_index.release();

// reassign
return refine_index;
} else {
// being IndexScalarQuantizer as a refine
auto refine_sq_type = get_sq_quantizer_type(refine_type.value());

// a redundant check
if (!refine_sq_type.has_value()) {
LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value();
return expected<std::unique_ptr<faiss::Index>>::Err(
Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value()));
}

// create an sq
auto sq_refine = std::make_unique<faiss::IndexScalarQuantizer>(
local_hnsw_index->storage->d, refine_sq_type.value(), local_hnsw_index->storage->metric_type);

auto refine_index = std::make_unique<faiss::IndexRefine>(local_hnsw_index.get(), sq_refine.get());

// let refine_index to own everything
refine_index->own_refine_index = true;
refine_index->own_fields = true;
local_hnsw_index.release();
sq_refine.release();

// reassign
return refine_index;
}
}

} // namespace

//
class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
public:
Expand Down Expand Up @@ -2300,7 +2102,10 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {

if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
// yes
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index));
const auto hnsw_d = hnsw_index->storage->d;
const auto hnsw_metric_type = hnsw_index->storage->metric_type;
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index),
hnsw_d, hnsw_metric_type);
if (!final_index_cnd.has_value()) {
return Status::invalid_args;
}
Expand Down Expand Up @@ -2368,7 +2173,7 @@ class BaseFaissRegularIndexHNSWSQNodeTemplate : public BaseFaissRegularIndexHNSW
return true;
}

return has_lossless_refine_index(hnsw_sq_cfg, datatype_v<DataType>);
return has_lossless_refine_index(hnsw_sq_cfg.refine, hnsw_sq_cfg.refine_type, datatype_v<DataType>);
}
};

Expand Down Expand Up @@ -2449,7 +2254,10 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode {
std::unique_ptr<faiss::Index> final_index;
if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
// yes
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index));
const auto hnsw_d = hnsw_index->storage->d;
const auto hnsw_metric_type = hnsw_index->storage->metric_type;
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index),
hnsw_d, hnsw_metric_type);
if (!final_index_cnd.has_value()) {
return Status::invalid_args;
}
Expand Down Expand Up @@ -2640,7 +2448,7 @@ class BaseFaissRegularIndexHNSWPQNodeTemplate : public BaseFaissRegularIndexHNSW
static bool
StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) {
auto hnsw_cfg = static_cast<const FaissHnswConfig&>(config);
return has_lossless_refine_index(hnsw_cfg, datatype_v<DataType>);
return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v<DataType>);
}
};

Expand Down Expand Up @@ -2728,7 +2536,10 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode {
std::unique_ptr<faiss::Index> final_index;
if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
// yes
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index));
const auto hnsw_d = hnsw_index->storage->d;
const auto hnsw_metric_type = hnsw_index->storage->metric_type;
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index),
hnsw_d, hnsw_metric_type);
if (!final_index_cnd.has_value()) {
return Status::invalid_args;
}
Expand Down Expand Up @@ -2920,7 +2731,7 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS
static bool
StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) {
auto hnsw_cfg = static_cast<const FaissHnswConfig&>(config);
return has_lossless_refine_index(hnsw_cfg, datatype_v<DataType>);
return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v<DataType>);
}
};

Expand Down
58 changes: 53 additions & 5 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "faiss/IndexIVFPQFastScan.h"
#include "faiss/IndexIVFRaBitQ.h"
#include "faiss/IndexIVFScalarQuantizerCC.h"
#include "faiss/IndexPreTransform.h"
#include "faiss/IndexScaNN.h"
#include "faiss/IndexScalarQuantizer.h"
#include "faiss/VectorTransform.h"
Expand Down Expand Up @@ -670,9 +669,15 @@ IvfIndexNode<DataType, IndexType>::TrainInternal(const DataSetPtr dataset, std::
if constexpr (std::is_same<IndexIVFRaBitQWrapper, IndexType>::value) {
const IvfRaBitQConfig& ivf_rabitq_cfg = static_cast<const IvfRaBitQConfig&>(*cfg);
auto nlist = MatchNlist(rows, ivf_rabitq_cfg.nlist.value());
auto qb = ivf_rabitq_cfg.rbq_bits_query.value();

index = std::make_unique<IndexIVFRaBitQWrapper>(dim, nlist, qb, metric.value());
DataFormatEnum data_format = DataType2EnumHelper<DataType>::value;

auto result = IndexIVFRaBitQWrapper::create(dim, nlist, ivf_rabitq_cfg, data_format, metric.value());
if (!result.has_value()) {
return result.error();
}

index = std::move(result.value());
index->train(rows, (const float*)data);
}
index_ = std::move(index);
Expand Down Expand Up @@ -835,13 +840,36 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSetPtr dataset, std::unique_

const IvfRaBitQConfig& ivf_rabitq_cfg = static_cast<const IvfRaBitQConfig&>(*cfg);

// use refine?
bool use_refine = false;

const bool whether_to_enable_refine = ivf_rabitq_cfg.refine_k.has_value();
if (const auto wrapper_index = dynamic_cast<const IndexIVFRaBitQWrapper*>(index_.get());
wrapper_index != nullptr) {
const faiss::IndexRefine* refine_index = wrapper_index->get_refine_index();
use_refine = (refine_index != nullptr);
}

faiss::IVFRaBitQSearchParameters ivf_search_params;
ivf_search_params.nprobe = nprobe;
ivf_search_params.max_codes = 0;
ivf_search_params.sel = id_selector;
ivf_search_params.qb = ivf_rabitq_cfg.rbq_bits_query.value_or(0);

index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset, &ivf_search_params);
if (use_refine && whether_to_enable_refine) {
// yes, use refine
faiss::IndexRefineSearchParameters refine_search_params;
refine_search_params.sel = id_selector;
refine_search_params.k_factor = ivf_rabitq_cfg.refine_k.value_or(1);
refine_search_params.base_index_params = &ivf_search_params;

index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset,
&refine_search_params);
} else {
// do not use refine
index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset,
&ivf_search_params);
}
} else {
auto cur_query = (const float*)data + index * dim;
if (is_cosine) {
Expand Down Expand Up @@ -964,7 +992,27 @@ IvfIndexNode<DataType, IndexType>::RangeSearch(const DataSetPtr dataset, std::un
ivf_search_params.sel = id_selector;
ivf_search_params.qb = ivf_rabitq_cfg.rbq_bits_query.value_or(0);

index_->range_search(1, cur_query, radius, &res, &ivf_search_params);
// use refine?
bool use_refine = false;

const bool whether_to_enable_refine = ivf_rabitq_cfg.refine_k.has_value();
if (const auto wrapper_index = dynamic_cast<const IndexIVFRaBitQWrapper*>(index_.get());
wrapper_index != nullptr) {
const faiss::IndexRefine* refine_index = wrapper_index->get_refine_index();
use_refine = (refine_index != nullptr);
}

if (use_refine && whether_to_enable_refine) {
// yes, use refine
faiss::IndexRefineSearchParameters refine_search_params;
refine_search_params.sel = id_selector;
refine_search_params.k_factor = ivf_rabitq_cfg.refine_k.value_or(1);
refine_search_params.base_index_params = &ivf_search_params;

index_->range_search(1, cur_query, radius, &res, &refine_search_params);
} else {
index_->range_search(1, cur_query, radius, &res, &ivf_search_params);
}
} else {
auto cur_query = (const float*)xq + index * dim;
if (is_cosine) {
Expand Down
Loading
Loading