Skip to content

Commit 2b56087

Browse files
authored
adapt to emblist strategy (#1564)
Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
1 parent 33c5782 commit 2b56087

8 files changed

Lines changed: 704 additions & 49 deletions

File tree

include/knowhere/config.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ class BaseConfig : public Config {
849849
KNOWHERE_CONFIG_DECLARE_FIELD(lemur_hidden_dim)
850850
.description("Hidden dimension for LEMUR MLP (compressed representation dimension)")
851851
.set_default(256)
852-
.set_range(32, 4096)
852+
.set_range(8, 8192)
853853
.for_train();
854854
KNOWHERE_CONFIG_DECLARE_FIELD(lemur_num_train_samples)
855855
.description("Number of training samples for LEMUR MLP")

include/knowhere/emb_list_utils.h

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "knowhere/bitsetview.h"
2222
#include "knowhere/log.h"
2323
#include "knowhere/object.h"
24+
#include "knowhere/utils.h"
2425

2526
namespace knowhere {
2627

@@ -199,7 +200,7 @@ using EmbListAggFunc = std::function<std::optional<float>(const float*, size_t,
199200

200201
inline std::optional<EmbListAggFunc>
201202
get_emb_list_agg_func(const std::string& el_metric_type) {
202-
if (el_metric_type == metric::MAX_SIM) {
203+
if (IsMetricType(el_metric_type, metric::MAX_SIM)) {
203204
return get_sum_max_sim;
204205
}
205206
return nullptr;
@@ -212,13 +213,13 @@ get_emb_list_agg_func(const std::string& el_metric_type) {
212213
*/
213214
inline std::optional<std::string>
214215
get_el_metric_type(const std::string& metric_type) {
215-
if (metric_type == metric::MAX_SIM || metric_type == metric::MAX_SIM_IP || metric_type == metric::MAX_SIM_L2 ||
216-
metric_type == metric::MAX_SIM_COSINE || metric_type == metric::MAX_SIM_HAMMING ||
217-
metric_type == metric::MAX_SIM_JACCARD) {
216+
if (IsMetricType(metric_type, metric::MAX_SIM) || IsMetricType(metric_type, metric::MAX_SIM_IP) ||
217+
IsMetricType(metric_type, metric::MAX_SIM_L2) || IsMetricType(metric_type, metric::MAX_SIM_COSINE) ||
218+
IsMetricType(metric_type, metric::MAX_SIM_HAMMING) || IsMetricType(metric_type, metric::MAX_SIM_JACCARD)) {
218219
return metric::MAX_SIM;
219-
} else if (metric_type == metric::DTW || metric_type == metric::DTW_IP || metric_type == metric::DTW_L2 ||
220-
metric_type == metric::DTW_COSINE || metric_type == metric::DTW_HAMMING ||
221-
metric_type == metric::DTW_JACCARD) {
220+
} else if (IsMetricType(metric_type, metric::DTW) || IsMetricType(metric_type, metric::DTW_IP) ||
221+
IsMetricType(metric_type, metric::DTW_L2) || IsMetricType(metric_type, metric::DTW_COSINE) ||
222+
IsMetricType(metric_type, metric::DTW_HAMMING) || IsMetricType(metric_type, metric::DTW_JACCARD)) {
222223
return metric::DTW;
223224
}
224225
return std::nullopt;
@@ -231,20 +232,20 @@ get_el_metric_type(const std::string& metric_type) {
231232
*/
232233
inline std::optional<std::string>
233234
get_sub_metric_type(const std::string& metric_type) {
234-
if (metric_type == metric::MAX_SIM_COSINE || metric_type == metric::MAX_SIM || metric_type == metric::DTW_COSINE ||
235-
metric_type == metric::DTW) {
235+
if (IsMetricType(metric_type, metric::MAX_SIM_COSINE) || IsMetricType(metric_type, metric::MAX_SIM) ||
236+
IsMetricType(metric_type, metric::DTW_COSINE) || IsMetricType(metric_type, metric::DTW)) {
236237
return metric::COSINE;
237238
}
238-
if (metric_type == metric::MAX_SIM_IP || metric_type == metric::DTW_IP) {
239+
if (IsMetricType(metric_type, metric::MAX_SIM_IP) || IsMetricType(metric_type, metric::DTW_IP)) {
239240
return metric::IP;
240241
}
241-
if (metric_type == metric::MAX_SIM_L2 || metric_type == metric::DTW_L2) {
242+
if (IsMetricType(metric_type, metric::MAX_SIM_L2) || IsMetricType(metric_type, metric::DTW_L2)) {
242243
return metric::L2;
243244
}
244-
if (metric_type == metric::MAX_SIM_HAMMING || metric_type == metric::DTW_HAMMING) {
245+
if (IsMetricType(metric_type, metric::MAX_SIM_HAMMING) || IsMetricType(metric_type, metric::DTW_HAMMING)) {
245246
return metric::HAMMING;
246247
}
247-
if (metric_type == metric::MAX_SIM_JACCARD || metric_type == metric::DTW_JACCARD) {
248+
if (IsMetricType(metric_type, metric::MAX_SIM_JACCARD) || IsMetricType(metric_type, metric::DTW_JACCARD)) {
248249
return metric::JACCARD;
249250
}
250251
return std::nullopt;

src/index/diskann/diskann.cc

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,25 @@ class DiskANNIndexNode : public IndexNode {
104104
static bool
105105
StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) {
106106
knowhere::MetricType metric_type = config.metric_type.has_value() ? config.metric_type.value() : "";
107-
return IsMetricType(metric_type, metric::L2) || IsMetricType(metric_type, metric::COSINE);
107+
const auto& base_metric = get_sub_metric_type(metric_type).value_or(metric_type);
108+
return IsMetricType(base_metric, metric::L2) || IsMetricType(base_metric, metric::COSINE);
109+
}
110+
111+
static Status
112+
StaticConfigCheck(const Config& cfg, PARAM_TYPE paramType, std::string& msg) {
113+
auto& base_cfg = static_cast<const BaseConfig&>(cfg);
114+
auto strategy = base_cfg.emb_list_strategy.value_or("");
115+
if (strategy == meta::EMB_LIST_STRATEGY_MUVERA || strategy == meta::EMB_LIST_STRATEGY_LEMUR) {
116+
msg = "DiskANN only supports TokenANN strategy, got '" + strategy + "'";
117+
return Status::invalid_args;
118+
}
119+
return Status::success;
108120
}
109121

110122
bool
111123
HasRawData(const std::string& metric_type) const override {
112-
return IsMetricType(metric_type, metric::L2) || IsMetricType(metric_type, metric::COSINE);
124+
const auto& base_metric = get_sub_metric_type(metric_type).value_or(metric_type);
125+
return IsMetricType(base_metric, metric::L2) || IsMetricType(base_metric, metric::COSINE);
113126
}
114127

115128
expected<DataSetPtr>
@@ -497,6 +510,13 @@ DiskANNIndexNode<DataType>::BuildEmbListIfNeed(const DataSetPtr dataset, std::sh
497510
return Build(dataset, std::move(cfg), use_knowhere_build_pool);
498511
}
499512

513+
// DiskANN only supports TokenANN strategy
514+
auto strategy_type = config.emb_list_strategy.value_or(meta::EMB_LIST_STRATEGY_TOKENANN);
515+
if (strategy_type != meta::EMB_LIST_STRATEGY_TOKENANN) {
516+
LOG_KNOWHERE_ERROR_ << "DiskANN only supports TokenANN strategy, got: " << strategy_type;
517+
return Status::invalid_args;
518+
}
519+
500520
LOG_KNOWHERE_INFO_ << "Build emb_list index and read emb_list offset from file.";
501521

502522
// Validate and get the emb_list offset file path

src/index/index_node.cc

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -287,25 +287,24 @@ IndexNode::GetEmbListByIds(const DataSetPtr dataset, const std::string& metric_t
287287
"GetEmbListByIds requires emb_list_offset, but it is not available");
288288
}
289289
auto sub_metric = get_sub_metric_type(metric_type);
290-
if (!sub_metric.has_value() || !HasRawData(sub_metric.value())) {
290+
if (!sub_metric.has_value()) {
291+
return expected<DataSetPtr>::Err(Status::not_implemented,
292+
"GetEmbListByIds: invalid metric type " + metric_type);
293+
}
294+
295+
// Raw data can come from emb_list_raw_index_ (MUVERA/LEMUR) or base index (TokenANN)
296+
bool use_raw_index = (emb_list_raw_index_ != nullptr);
297+
if (!use_raw_index && !HasRawData(sub_metric.value())) {
291298
return expected<DataSetPtr>::Err(
292299
Status::not_implemented,
293300
"GetEmbListByIds requires raw data support, but the index does not store raw vectors");
294301
}
295302

296303
auto num_el_ids = dataset->GetRows();
297304
auto el_ids = dataset->GetIds();
298-
auto dim = Dim();
299-
300-
// Build the output offset array and collect all vector-level IDs in a single pass.
301-
//
302-
// TODO(perf): Vectors within each embedding list are contiguous in the index. However, the current
303-
// implementation collects all these contiguous IDs into a flat array and passes them to GetVectorByIds,
304-
// which internally calls reconstruct(id, ...) one vector at a time. This could be optimized by using
305-
// reconstruct_n(start, len, ...) or direct memcpy from raw data storage, avoiding both the redundant
306-
// ID array allocation and per-vector overhead. We don't do this yet because it would require
307-
// index-type-specific implementations (HNSW, IVF, FLAT, etc. each store raw data differently),
308-
// whereas the current approach works generically across all index types via the GetVectorByIds interface.
305+
auto dim = use_raw_index ? emb_list_raw_index_->d : Dim();
306+
307+
// Build the output offset array
309308
std::vector<size_t> out_offsets(num_el_ids + 1);
310309
out_offsets[0] = 0;
311310
for (int64_t i = 0; i < num_el_ids; i++) {
@@ -318,17 +317,9 @@ IndexNode::GetEmbListByIds(const DataSetPtr dataset, const std::string& metric_t
318317
out_offsets[i + 1] = out_offsets[i] + emb_list_offset_->get_el_len(el_id);
319318
}
320319

321-
std::vector<int64_t> vec_ids;
322-
vec_ids.reserve(out_offsets[num_el_ids]);
323-
for (int64_t i = 0; i < num_el_ids; i++) {
324-
size_t start = emb_list_offset_->offset[el_ids[i]];
325-
size_t len = out_offsets[i + 1] - out_offsets[i];
326-
for (size_t j = 0; j < len; j++) {
327-
vec_ids.push_back(static_cast<int64_t>(start + j));
328-
}
329-
}
320+
auto total_vecs = out_offsets[num_el_ids];
330321

331-
if (vec_ids.empty()) {
322+
if (total_vecs == 0) {
332323
// all emblist are empty list
333324
auto result = GenResultDataSet(num_el_ids, dim, (const void*)nullptr);
334325
auto* offsets_ptr = new size_t[out_offsets.size()];
@@ -337,16 +328,52 @@ IndexNode::GetEmbListByIds(const DataSetPtr dataset, const std::string& metric_t
337328
return result;
338329
}
339330

340-
auto vec_dataset = GenIdsDataSet(vec_ids.size(), vec_ids.data());
341-
auto res = GetVectorByIds(vec_dataset, op_context);
342-
if (!res.has_value()) {
343-
return res;
344-
}
331+
const void* tensor = nullptr;
332+
333+
if (use_raw_index) {
334+
// MUVERA/LEMUR: vectors are contiguous per el in emb_list_raw_index_, use reconstruct_n
335+
auto data = std::make_unique<float[]>(total_vecs * dim);
336+
float* ptr = data.get();
337+
for (int64_t i = 0; i < num_el_ids; i++) {
338+
auto start = static_cast<int64_t>(emb_list_offset_->offset[el_ids[i]]);
339+
auto len = static_cast<int64_t>(out_offsets[i + 1] - out_offsets[i]);
340+
if (len > 0) {
341+
emb_list_raw_index_->reconstruct_n(start, len, ptr);
342+
ptr += len * dim;
343+
}
344+
}
345+
tensor = data.release();
346+
} else {
347+
// TokenANN: collect vec_ids and use base index GetVectorByIds
348+
//
349+
// TODO(perf): Vectors within each embedding list are contiguous in the index. However, the current
350+
// implementation collects all these contiguous IDs into a flat array and passes them to GetVectorByIds,
351+
// which internally calls reconstruct(id, ...) one vector at a time. This could be optimized by using
352+
// reconstruct_n(start, len, ...) or direct memcpy from raw data storage, avoiding both the redundant
353+
// ID array allocation and per-vector overhead. We don't do this yet because it would require
354+
// index-type-specific implementations (HNSW, IVF, FLAT, etc. each store raw data differently),
355+
// whereas the current approach works generically across all index types via the GetVectorByIds interface.
356+
std::vector<int64_t> vec_ids;
357+
vec_ids.reserve(total_vecs);
358+
for (int64_t i = 0; i < num_el_ids; i++) {
359+
size_t start = emb_list_offset_->offset[el_ids[i]];
360+
size_t len = out_offsets[i + 1] - out_offsets[i];
361+
for (size_t j = 0; j < len; j++) {
362+
vec_ids.push_back(static_cast<int64_t>(start + j));
363+
}
364+
}
365+
366+
// Build result: transfer tensor ownership from GetVectorByIds result to new dataset
367+
auto vec_dataset = GenIdsDataSet(vec_ids.size(), vec_ids.data());
368+
auto res = GetVectorByIds(vec_dataset, op_context);
369+
if (!res.has_value()) {
370+
return res;
371+
}
345372

346-
// Build result: transfer tensor ownership from GetVectorByIds result to new dataset
347-
auto vec_result = res.value();
348-
auto tensor = vec_result->GetTensor();
349-
vec_result->SetIsOwner(false);
373+
auto vec_result = res.value();
374+
tensor = vec_result->GetTensor();
375+
vec_result->SetIsOwner(false);
376+
}
350377

351378
auto result = GenResultDataSet(num_el_ids, dim, tensor);
352379
auto* offsets_ptr = new size_t[out_offsets.size()];

src/index/index_static.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ IndexStaticFaced<DataType>::ConfigCheck(const IndexType& indexType, const IndexV
5656
return status;
5757
}
5858

59+
if constexpr (!std::is_same_v<DataType, knowhere::fp32>) {
60+
auto strategy = cfg->emb_list_strategy.value_or("");
61+
if (strategy == meta::EMB_LIST_STRATEGY_MUVERA || strategy == meta::EMB_LIST_STRATEGY_LEMUR) {
62+
msg = "MUVERA/LEMUR strategies only support fp32 data type, got '" + strategy + "'";
63+
return Status::invalid_args;
64+
}
65+
}
66+
5967
if (Instance().staticConfigCheckMap.find(indexType) != Instance().staticConfigCheckMap.end()) {
6068
return Instance().staticConfigCheckMap[indexType](*cfg, knowhere::PARAM_TYPE::TRAIN, msg);
6169
}

0 commit comments

Comments
 (0)