Skip to content

Commit fd0871f

Browse files
authored
enhance: knowhere support metric mhjaccard and index minhash lsh for minhash vector (#1203)
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
1 parent 33e13f3 commit fd0871f

33 files changed

Lines changed: 2277 additions & 20 deletions

cmake/libs/libfaiss.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ if(__X86_64)
5050
${UTILS_SRC} $<TARGET_OBJECTS:utils_sse> $<TARGET_OBJECTS:utils_avx>
5151
$<TARGET_OBJECTS:utils_avx512> $<TARGET_OBJECTS:utils_avx512icx>)
5252
target_link_libraries(knowhere_utils PUBLIC glog::glog)
53+
target_link_libraries(knowhere_utils PUBLIC xxHash::xxhash)
5354
endif()
5455

5556
if(__AARCH64)
@@ -99,19 +100,22 @@ if(__AARCH64)
99100
endif()
100101

101102
target_link_libraries(knowhere_utils PUBLIC glog::glog)
103+
target_link_libraries(knowhere_utils PUBLIC xxHash::xxhash)
102104
endif()
103105

104106
if(__RISCV64)
105107
set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc)
106108
add_library(knowhere_utils STATIC ${UTILS_SRC})
107109
target_link_libraries(knowhere_utils PUBLIC glog::glog)
110+
target_link_libraries(knowhere_utils PUBLIC xxHash::xxhash)
108111
endif()
109112

110113
# ToDo: Add distances_vsx.cc for powerpc64 SIMD acceleration
111114
if(__PPC64)
112115
set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc src/simd/distances_powerpc.cc)
113116
add_library(knowhere_utils STATIC ${UTILS_SRC})
114117
target_link_libraries(knowhere_utils PUBLIC glog::glog)
118+
target_link_libraries(knowhere_utils PUBLIC xxHash::xxhash)
115119
endif()
116120

117121

@@ -131,6 +135,9 @@ else()
131135
find_package(BLAS REQUIRED)
132136
endif()
133137

138+
find_package(xxHash REQUIRED)
139+
include_directories(${xxHash_INCLUDE_DIRS})
140+
134141
if(__X86_64)
135142
list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX2_SRCS})
136143

conanfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def requirements(self):
106106
self.requires("folly/2023.10.30.09@milvus/dev")
107107
self.requires("libcurl/8.2.1")
108108
self.requires("simde/0.8.2")
109+
self.requires("xxhash/0.8.2")
109110
if self.settings.os == "Android":
110111
self.requires("openblas/0.3.27")
111112
if not self.options.with_light:
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
//
2+
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
3+
// with the License. You may obtain a copy of the License at
4+
//
5+
// http://www.apache.org/licenses/LICENSE-2.0
6+
//
7+
// Unless required by applicable law or agreed to in writing, software distributed under the License
8+
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
9+
// or implied. See the License for the specific language governing permissions and limitations under the License.
10+
#ifndef KNOWHERE_KNOWHERE_H
11+
#define KNOWHERE_KNOWHERE_H
12+
#include <cmath>
13+
#include <cstring>
14+
#include <fstream>
15+
#include <functional>
16+
#include <vector>
17+
18+
#include "io/memory_io.h"
19+
#include "knowhere/bitsetview.h"
20+
#include "knowhere/utils.h"
21+
namespace knowhere {
22+
template <typename T>
23+
class BloomFilter {
24+
public:
25+
BloomFilter() {
26+
}
27+
BloomFilter(size_t expected_elements, double false_positive_prob) : n(expected_elements), p(false_positive_prob) {
28+
m = static_cast<size_t>(-(n * log(p)) / (log(2) * log(2)));
29+
k = static_cast<int>(m / n * log(2));
30+
m = std::max<size_t>(m, 1);
31+
k = std::max(k, 1);
32+
bits.resize(m, false);
33+
}
34+
35+
void
36+
add(const T& element) {
37+
size_t glb_hash = hash((const char*)&element, sizeof(element), 0);
38+
for (int i = 0; i < k; ++i) {
39+
size_t pos = (glb_hash + i) % m;
40+
bits[pos] = true;
41+
}
42+
}
43+
44+
bool
45+
contains(const T& element) const {
46+
size_t glb_hash = hash((const char*)&element, sizeof(element), 0);
47+
for (int i = 0; i < k; ++i) {
48+
size_t pos = (glb_hash + i) % m;
49+
if (!bits[pos])
50+
return false;
51+
}
52+
return true;
53+
}
54+
55+
void
56+
save(MemoryIOWriter& writer) const {
57+
writeBinaryPOD(writer, m);
58+
writeBinaryPOD(writer, k);
59+
writeBinaryPOD(writer, n);
60+
writeBinaryPOD(writer, p);
61+
auto bytes_num = (m + 8 - 1) / 8;
62+
std::vector<char> buffer(bytes_num, 0);
63+
for (size_t i = 0; i < m; ++i) {
64+
if (bits[i]) {
65+
buffer[i / 8] |= (1 << (i % 8));
66+
}
67+
}
68+
writer.write(buffer.data(), buffer.size());
69+
}
70+
71+
void
72+
load(MemoryIOReader& reader) {
73+
readBinaryPOD(reader, m);
74+
readBinaryPOD(reader, k);
75+
readBinaryPOD(reader, n);
76+
readBinaryPOD(reader, p);
77+
bits.clear();
78+
bits.resize(m);
79+
auto bytes_num = (m + 8 - 1) / 8;
80+
std::vector<char> buffer(bytes_num);
81+
reader.read(buffer.data(), bytes_num);
82+
for (size_t i = 0; i < m; ++i) {
83+
bool bit = (buffer[i / 8] >> (i % 8)) & 1;
84+
bits.push_back(bit);
85+
}
86+
}
87+
size_t
88+
size() const {
89+
return n;
90+
}
91+
double
92+
false_positive_rate() const {
93+
return p;
94+
}
95+
size_t
96+
memory_usage() const {
97+
return m / 8;
98+
}
99+
100+
private:
101+
static constexpr size_t multiplier = 31;
102+
std::vector<bool> bits;
103+
size_t m = 0;
104+
int k = 0;
105+
double p = 0;
106+
size_t n = 0;
107+
108+
size_t
109+
hash(const char* data, size_t length, size_t bucket_i) const {
110+
if (data == nullptr) {
111+
throw std::runtime_error("can't hash null data.");
112+
}
113+
size_t result = 0;
114+
for (size_t i = 0; i < length; ++i) {
115+
result = (result * multiplier) + static_cast<size_t>(data[i]);
116+
}
117+
return (result + bucket_i) % m;
118+
}
119+
};
120+
} // namespace knowhere
121+
#endif

include/knowhere/comp/index_param.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ constexpr const char* INDEX_HNSW_PQ = "HNSW_PQ";
5656
constexpr const char* INDEX_HNSW_PRQ = "HNSW_PRQ";
5757

5858
constexpr const char* INDEX_DISKANN = "DISKANN";
59+
constexpr const char* INDEX_MINHASH_INDEX = "MINHASH_LSH";
5960

6061
constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
6162
constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND";
@@ -187,6 +188,16 @@ constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search";
187188
// RaBitQ Params
188189
constexpr const char* RABITQ_QUERY_BITS = "rbq_bits_query";
189190

191+
// minhash Params
192+
constexpr const char* ALIGNED_BLOCK_SIZE = "aligned_block_size";
193+
constexpr const char* BAND = "band";
194+
constexpr const char* SHARED_BLOOM_FILTER = "shared_bloom_filter";
195+
constexpr const char* BLOOM_FALSE_POSITIVE_RPOB = "bloom_false_positive_prob";
196+
constexpr const char* HASH_CODE_IN_MEM = "hash_code_in_mem";
197+
constexpr const char* SEARCH_WITH_JACCARD = "search_with_jaccard";
198+
constexpr const char* MH_ELEMENT_BIT_WIDTH = "mh_element_bit_width";
199+
constexpr const char* MH_LSH_REFINE_K = "refine_k";
200+
constexpr const char* BATCH_SEARCH = "batch_search";
190201
} // namespace indexparam
191202

192203
using MetricType = std::string;
@@ -197,6 +208,7 @@ constexpr const char* L2 = "L2";
197208
constexpr const char* COSINE = "COSINE";
198209
constexpr const char* HAMMING = "HAMMING";
199210
constexpr const char* JACCARD = "JACCARD";
211+
constexpr const char* MHJACCARD = "MHJACCARD";
200212
constexpr const char* SUBSTRUCTURE = "SUBSTRUCTURE";
201213
constexpr const char* SUPERSTRUCTURE = "SUPERSTRUCTURE";
202214
constexpr const char* BM25 = "BM25";

include/knowhere/config.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,12 @@ class BaseConfig : public Config {
633633
*/
634634
CFG_INT refine_type;
635635
CFG_BOOL refine_with_quant;
636+
/*
637+
* band is a special parameter for BF search and MinHash index node training.
638+
*/
639+
CFG_INT band;
640+
CFG_BOOL search_with_jaccard;
641+
CFG_INT mh_element_bit_width;
636642
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
637643
KNOWHERE_CONFIG_DECLARE_FIELD(dim).allow_empty_without_default().description("vector dim").for_train();
638644
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type)
@@ -781,6 +787,17 @@ class BaseConfig : public Config {
781787
.for_search()
782788
.for_range_search()
783789
.for_iterator();
790+
KNOWHERE_CONFIG_DECLARE_FIELD(band).description("param of MinHashLSH").set_default(1).for_train().for_search();
791+
KNOWHERE_CONFIG_DECLARE_FIELD(mh_element_bit_width)
792+
.description("sizeof(hash code), the hash element should be aligned on 8 bits")
793+
.set_default(8)
794+
.set_range(8, 256)
795+
.for_train()
796+
.for_search();
797+
KNOWHERE_CONFIG_DECLARE_FIELD(search_with_jaccard)
798+
.description("return the jaccard distance of minhash vector search or minhashlsh hit flag.")
799+
.set_default(false)
800+
.for_search();
784801
}
785802
};
786803
} // namespace knowhere

include/knowhere/index/index_factory.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ class IndexFactory {
143143
#define KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(name, index_node, features, ...) \
144144
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, int8, (features | knowhere::feature::INT8), ##__VA_ARGS__);
145145

146+
// register vector index supporting binary data types
147+
#define KNOWHERE_MOCK_REGISTER_DENSE_BINARY_ALL_GLOBAL(name, index_node, features, ...) \
148+
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__);
149+
146150
// register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types, but mocked bf16 and fp16
147151
#define KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
148152
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \

include/knowhere/index/index_table.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
9898
// sparse index
9999
{IndexEnum::INDEX_SPARSE_INVERTED_INDEX, VecType::VECTOR_SPARSE_FLOAT},
100100
{IndexEnum::INDEX_SPARSE_WAND, VecType::VECTOR_SPARSE_FLOAT},
101+
// minhash index
102+
{IndexEnum::INDEX_MINHASH_INDEX, VecType::VECTOR_BINARY},
101103
};
102104

103105
static std::set<std::string> legal_support_mmap_knowhere_index = {

include/knowhere/utils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <strings.h>
1515

1616
#include <algorithm>
17+
#include <fstream>
1718
#include <vector>
1819

1920
#include "knowhere/binaryset.h"
@@ -226,6 +227,25 @@ readBinaryPOD(R& in, T& podRef) {
226227
in.read((char*)&podRef, sizeof(T));
227228
}
228229

230+
/**
 * Loads a packed binary-vector file into memory.
 *
 * Layout as read here: a uint32 point count, a uint32 dimension counted in bits, then
 * npts * dim / 8 payload bytes. Both header fields are read as raw host-order values.
 *
 * @param bin_file path of the file to read
 * @param data     receives a newly allocated buffer holding the payload bytes
 * @param npts     receives the point count from the header
 * @param dim      receives the dimension (in bits) from the header
 * @throws std::runtime_error if the file cannot be opened, the header or payload is
 *         shorter than declared, or dim is not a multiple of 8
 */
inline void
load_binary_vec(const std::string& bin_file, std::unique_ptr<char[]>& data, size_t& npts, size_t& dim) {
    std::ifstream file(bin_file, std::ios::binary);
    if (!file.is_open()) {
        throw std::runtime_error("fail to open file: " + bin_file);
    }
    uint32_t n = 0;
    uint32_t d = 0;
    file.read(reinterpret_cast<char*>(&n), sizeof(uint32_t));
    file.read(reinterpret_cast<char*>(&d), sizeof(uint32_t));
    // A short read sets failbit; without this check n/d would be used without ever being filled.
    if (!file) {
        throw std::runtime_error("fail to read binary vector header from file: " + bin_file);
    }
    npts = n;
    dim = d;
    if (dim % 8 != 0) {
        throw std::runtime_error("fail to load binary vector base file, dim % 8 != 0 ");
    }
    // dim is a multiple of 8 here, so dividing first is exact and keeps the product small.
    const uint64_t total_size = static_cast<uint64_t>(dim / 8) * npts;
    data = std::make_unique<char[]>(total_size);
    file.read(data.get(), static_cast<std::streamsize>(total_size));
    // Reject truncated files instead of handing back a partially filled buffer.
    if (!file) {
        throw std::runtime_error("fail to read binary vector data from file: " + bin_file);
    }
}
248+
229249
// taken from
230250
// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
231251
// round up X to the nearest multiple of Y

src/common/comp/brute_force.cc

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "faiss/utils/binary_distances.h"
1919
#include "faiss/utils/distances.h"
2020
#include "faiss/utils/distances_typed.h"
21+
#include "index/minhash/minhash_util.h"
2122
#include "knowhere/bitsetview_idselector.h"
2223
#include "knowhere/comp/thread_pool.h"
2324
#include "knowhere/config.h"
@@ -151,6 +152,14 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
151152
auto labels = std::make_unique<int64_t[]>(nq * topk);
152153
auto distances = std::make_unique<float[]>(nq * topk);
153154
std::unique_ptr<float[]> norms = is_cosine ? GetVecNorms<DataType>(base_dataset) : nullptr;
155+
// some check for minhash metric
156+
if (faiss_metric_type == faiss::METRIC_MinHash_Jaccard) {
157+
auto mh_valid_stat =
158+
MinhashConfigCheck(dim, datatype_v<DataType>, PARAM_TYPE::SEARCH | PARAM_TYPE::TRAIN, &cfg, &bitset);
159+
if (mh_valid_stat != Status::success) {
160+
return expected<DataSetPtr>::Err(mh_valid_stat, "MinhashConfigCheck() failed, please check the config.");
161+
}
162+
}
154163
auto pool = ThreadPool::GetGlobalSearchThreadPool();
155164
std::vector<folly::Future<Status>> futs;
156165
futs.reserve(nq);
@@ -213,6 +222,23 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
213222
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector);
214223
break;
215224
}
225+
case faiss::METRIC_MinHash_Jaccard: {
226+
size_t band = cfg.band.value();
227+
bool search_with_jaccard = cfg.search_with_jaccard.value();
228+
if (search_with_jaccard) {
229+
size_t hash_element_size = cfg.mh_element_bit_width.value() / 8; // in bytes
230+
size_t hash_element_length = dim / (hash_element_size * 8);
231+
auto cur_query = (const char*)xq + (dim / 8) * index;
232+
minhash_jaccard_knn_ny(cur_query, (const char*)xb, hash_element_length, hash_element_size, nb,
233+
topk, bitset, cur_distances, cur_labels);
234+
} else {
235+
size_t u8_dim = dim / 8;
236+
auto cur_query = (const char*)xq + u8_dim * index;
237+
minhash_lsh_hit_ny(cur_query, (const char*)xb, u8_dim, band, nb, topk, bitset, cur_distances,
238+
cur_labels);
239+
}
240+
break;
241+
}
216242
case faiss::METRIC_Hamming: {
217243
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
218244
std::vector<int32_t> int_distances(topk);
@@ -306,6 +332,14 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
306332
int topk = cfg.k.value();
307333
auto labels = ids;
308334
auto distances = dis;
335+
// some check for minhash metric
336+
if (faiss_metric_type == faiss::METRIC_MinHash_Jaccard) {
337+
auto mh_valid_stat =
338+
MinhashConfigCheck(dim, datatype_v<DataType>, PARAM_TYPE::SEARCH | PARAM_TYPE::TRAIN, &cfg, &bitset);
339+
if (mh_valid_stat != Status::success) {
340+
return mh_valid_stat;
341+
}
342+
}
309343

310344
std::unique_ptr<float[]> norms = is_cosine ? GetVecNorms<DataType>(base_dataset) : nullptr;
311345
auto pool = ThreadPool::GetGlobalSearchThreadPool();
@@ -363,6 +397,23 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
363397
}
364398
break;
365399
}
400+
case faiss::METRIC_MinHash_Jaccard: {
401+
size_t band = cfg.band.value();
402+
bool search_with_jaccard = cfg.search_with_jaccard.value();
403+
if (search_with_jaccard) {
404+
size_t hash_element_size = cfg.mh_element_bit_width.value() / 8; // in bytes
405+
size_t hash_element_length = dim / (hash_element_size * 8);
406+
auto cur_query = (const char*)xq + (dim / 8) * index;
407+
minhash_jaccard_knn_ny(cur_query, (const char*)xb, hash_element_length, hash_element_size, nb,
408+
topk, bitset, cur_distances, cur_labels);
409+
} else {
410+
size_t u8_dim = dim / 8;
411+
auto cur_query = (const char*)xq + u8_dim * index;
412+
minhash_lsh_hit_ny(cur_query, (const char*)xb, u8_dim, band, nb, topk, bitset, cur_distances,
413+
cur_labels);
414+
}
415+
break;
416+
}
366417
case faiss::METRIC_Jaccard: {
367418
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
368419
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
@@ -483,6 +534,10 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
483534
float range_filter = cfg.range_filter.value();
484535

485536
auto pool = ThreadPool::GetGlobalSearchThreadPool();
537+
// some check for minhash metric
538+
if (metric_str == metric::MHJACCARD) {
539+
return expected<DataSetPtr>::Err(Status::not_implemented, "minhash not support range search.");
540+
}
486541

487542
std::vector<std::vector<int64_t>> result_id_array(nq);
488543
std::vector<std::vector<float>> result_dist_array(nq);
@@ -758,6 +813,12 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
758813
return expected<std::vector<IndexNode::IteratorPtr>>::Err(result.error(), result.what());
759814
}
760815

816+
// some check for minhash metric
817+
if (metric_str == metric::MHJACCARD) {
818+
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::not_implemented,
819+
"minhash does not support iterator.");
820+
}
821+
761822
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
762823
// LCOV_EXCL_START
763824
std::shared_ptr<tracer::trace::Span> span = nullptr;

0 commit comments

Comments
 (0)