Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -748,8 +748,13 @@ class FaissHnswIterator : public IndexIterator {
FaissHnswIterator(const std::shared_ptr<faiss::Index>& index_in,
const std::shared_ptr<std::vector<uint32_t>>& labels_in, std::unique_ptr<float[]>&& query_in,
const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer,
const float refine_ratio = 0.5f, bool use_knowhere_search_pool = true)
: IndexIterator(larger_is_closer, use_knowhere_search_pool, refine_ratio), index{index_in}, labels{labels_in} {
const float refine_ratio = 0.5f, const std::vector<uint32_t>& label_to_internal_offset_in = {},
const uint32_t mv_base_offset_in = 0, bool use_knowhere_search_pool = true)
: IndexIterator(larger_is_closer, use_knowhere_search_pool, refine_ratio),
index{index_in},
labels{labels_in},
label_to_internal_offset(label_to_internal_offset_in),
mv_base_offset(mv_base_offset_in) {
workspace.accumulated_alpha =
(bitset_in.count() >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold))
? std::numeric_limits<float>::max()
Expand Down Expand Up @@ -989,13 +994,18 @@ class FaissHnswIterator : public IndexIterator {

float
raw_distance(int64_t id) override {
const float refined_distance = workspace.qdis_refine->operator()(id);
return refined_distance;
if (label_to_internal_offset.empty()) {
return workspace.qdis_refine->operator()(id);
}
auto mv_internal_offset = label_to_internal_offset[id] - mv_base_offset;
return workspace.qdis_refine->operator()(mv_internal_offset);
}

private:
std::shared_ptr<faiss::Index> index;
std::shared_ptr<std::vector<uint32_t>> labels;
const std::vector<uint32_t>& label_to_internal_offset; // internal_offset = label_to_internal_offset[label_id];
const uint32_t mv_base_offset; // mv_internal_offset = internal_offset - mv_base_offset;

FaissHnswIteratorWorkspace workspace;
};
Expand Down Expand Up @@ -1328,6 +1338,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {

expected<DataSetPtr>
RangeSearch(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
// if support ann_iterator, use iterator-based range_search (IndexNode::RangeSearch)
if (is_ann_iterator_supported()) {
return IndexNode::RangeSearch(dataset, std::move(cfg), bitset);
}
if (this->indexes.empty()) {
return expected<DataSetPtr>::Err(Status::empty_index, "index not loaded");
}
Expand Down Expand Up @@ -1637,7 +1651,15 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
}

public:
//
bool
is_ann_iterator_supported() const {
if (data_format != DataFormatEnum::fp32 && data_format != DataFormatEnum::fp16 &&
data_format != DataFormatEnum::bf16) {
return false;
}
return true;
}

expected<std::vector<IndexNode::IteratorPtr>>
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset,
bool use_knowhere_search_pool) const override {
Expand All @@ -1646,8 +1668,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::empty_index, "index not loaded");
}

if (data_format != DataFormatEnum::fp32 && data_format != DataFormatEnum::fp16 &&
data_format != DataFormatEnum::bf16) {
if (!is_ann_iterator_supported()) {
LOG_KNOWHERE_ERROR_ << "Unsupported data format";
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::invalid_args, "unsupported data format");
}
Expand Down Expand Up @@ -1698,9 +1719,13 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
// create an iterator and initialize it
// refine is not needed for flat
// hnsw_cfg.iterator_refine_ratio.value_or(0.5f)

uint32_t mv_base_offset = index_rows_sum.size() > index_id ? index_rows_sum[index_id] : 0;

auto it = std::make_shared<FaissHnswIterator>(
indexes[index_id], labels.empty() ? nullptr : labels[index_id], std::move(cur_query), bitset, ef,
larger_is_closer, iterator_refine_ratio, use_knowhere_search_pool);
larger_is_closer, iterator_refine_ratio, label_to_internal_offset, mv_base_offset,
use_knowhere_search_pool);
// store
vec[i] = it;
}
Expand Down
Loading