|
18 | 18 | #include "faiss/utils/binary_distances.h" |
19 | 19 | #include "faiss/utils/distances.h" |
20 | 20 | #include "faiss/utils/distances_typed.h" |
| 21 | +#include "index/minhash/minhash_util.h" |
21 | 22 | #include "knowhere/bitsetview_idselector.h" |
22 | 23 | #include "knowhere/comp/thread_pool.h" |
23 | 24 | #include "knowhere/config.h" |
@@ -151,6 +152,14 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset |
151 | 152 | auto labels = std::make_unique<int64_t[]>(nq * topk); |
152 | 153 | auto distances = std::make_unique<float[]>(nq * topk); |
153 | 154 | std::unique_ptr<float[]> norms = is_cosine ? GetVecNorms<DataType>(base_dataset) : nullptr; |
| 155 | + // some check for minhash metric |
| 156 | + if (faiss_metric_type == faiss::METRIC_MinHash_Jaccard) { |
| 157 | + auto mh_valid_stat = |
| 158 | + MinhashConfigCheck(dim, datatype_v<DataType>, PARAM_TYPE::SEARCH | PARAM_TYPE::TRAIN, &cfg, &bitset); |
| 159 | + if (mh_valid_stat != Status::success) { |
| 160 | + return expected<DataSetPtr>::Err(mh_valid_stat, "MinhashConfigCheck() failed, please check the config."); |
| 161 | + } |
| 162 | + } |
154 | 163 | auto pool = ThreadPool::GetGlobalSearchThreadPool(); |
155 | 164 | std::vector<folly::Future<Status>> futs; |
156 | 165 | futs.reserve(nq); |
@@ -213,6 +222,23 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset |
213 | 222 | binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector); |
214 | 223 | break; |
215 | 224 | } |
| 225 | + case faiss::METRIC_MinHash_Jaccard: { |
| 226 | + size_t band = cfg.band.value(); |
| 227 | + bool search_with_jaccard = cfg.search_with_jaccard.value(); |
| 228 | + if (search_with_jaccard) { |
| 229 | + size_t hash_element_size = cfg.mh_element_bit_width.value() / 8; // in bytes |
| 230 | + size_t hash_element_length = dim / (hash_element_size * 8); |
| 231 | + auto cur_query = (const char*)xq + (dim / 8) * index; |
| 232 | + minhash_jaccard_knn_ny(cur_query, (const char*)xb, hash_element_length, hash_element_size, nb, |
| 233 | + topk, bitset, cur_distances, cur_labels); |
| 234 | + } else { |
| 235 | + size_t u8_dim = dim / 8; |
| 236 | + auto cur_query = (const char*)xq + u8_dim * index; |
| 237 | + minhash_lsh_hit_ny(cur_query, (const char*)xb, u8_dim, band, nb, topk, bitset, cur_distances, |
| 238 | + cur_labels); |
| 239 | + } |
| 240 | + break; |
| 241 | + } |
216 | 242 | case faiss::METRIC_Hamming: { |
217 | 243 | auto cur_query = (const uint8_t*)xq + (dim / 8) * index; |
218 | 244 | std::vector<int32_t> int_distances(topk); |
@@ -306,6 +332,14 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ |
306 | 332 | int topk = cfg.k.value(); |
307 | 333 | auto labels = ids; |
308 | 334 | auto distances = dis; |
| 335 | + // some check for minhash metric |
| 336 | + if (faiss_metric_type == faiss::METRIC_MinHash_Jaccard) { |
| 337 | + auto mh_valid_stat = |
| 338 | + MinhashConfigCheck(dim, datatype_v<DataType>, PARAM_TYPE::SEARCH | PARAM_TYPE::TRAIN, &cfg, &bitset); |
| 339 | + if (mh_valid_stat != Status::success) { |
| 340 | + return mh_valid_stat; |
| 341 | + } |
| 342 | + } |
309 | 343 |
|
310 | 344 | std::unique_ptr<float[]> norms = is_cosine ? GetVecNorms<DataType>(base_dataset) : nullptr; |
311 | 345 | auto pool = ThreadPool::GetGlobalSearchThreadPool(); |
@@ -363,6 +397,23 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ |
363 | 397 | } |
364 | 398 | break; |
365 | 399 | } |
| 400 | + case faiss::METRIC_MinHash_Jaccard: { |
| 401 | + size_t band = cfg.band.value(); |
| 402 | + bool search_with_jaccard = cfg.search_with_jaccard.value(); |
| 403 | + if (search_with_jaccard) { |
| 404 | + size_t hash_element_size = cfg.mh_element_bit_width.value() / 8; // in bytes |
| 405 | + size_t hash_element_length = dim / (hash_element_size * 8); |
| 406 | + auto cur_query = (const char*)xq + (dim / 8) * index; |
| 407 | + minhash_jaccard_knn_ny(cur_query, (const char*)xb, hash_element_length, hash_element_size, nb, |
| 408 | + topk, bitset, cur_distances, cur_labels); |
| 409 | + } else { |
| 410 | + size_t u8_dim = dim / 8; |
| 411 | + auto cur_query = (const char*)xq + u8_dim * index; |
| 412 | + minhash_lsh_hit_ny(cur_query, (const char*)xb, u8_dim, band, nb, topk, bitset, cur_distances, |
| 413 | + cur_labels); |
| 414 | + } |
| 415 | + break; |
| 416 | + } |
366 | 417 | case faiss::METRIC_Jaccard: { |
367 | 418 | auto cur_query = (const uint8_t*)xq + (dim / 8) * index; |
368 | 419 | faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances}; |
@@ -483,6 +534,10 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da |
483 | 534 | float range_filter = cfg.range_filter.value(); |
484 | 535 |
|
485 | 536 | auto pool = ThreadPool::GetGlobalSearchThreadPool(); |
| 537 | + // some check for minhash metric |
| 538 | + if (metric_str == metric::MHJACCARD) { |
| 539 | + return expected<DataSetPtr>::Err(Status::not_implemented, "minhash not support range search."); |
| 540 | + } |
486 | 541 |
|
487 | 542 | std::vector<std::vector<int64_t>> result_id_array(nq); |
488 | 543 | std::vector<std::vector<float>> result_dist_array(nq); |
@@ -758,6 +813,12 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da |
758 | 813 | return expected<std::vector<IndexNode::IteratorPtr>>::Err(result.error(), result.what()); |
759 | 814 | } |
760 | 815 |
|
| 816 | + // some check for minhash metric |
| 817 | + if (metric_str == metric::MHJACCARD) { |
| 818 | + return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::not_implemented, |
| 819 | + "minhash does not support iterator."); |
| 820 | + } |
| 821 | + |
761 | 822 | #if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) |
762 | 823 | // LCOV_EXCL_START |
763 | 824 | std::shared_ptr<tracer::trace::Span> span = nullptr; |
|
0 commit comments