Skip to content

Commit 5d6cde2

Browse files
committed
enhance: knowhere support metric mhjaccard and index minhash lsh for minhash vector
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
1 parent b975bae commit 5d6cde2

33 files changed

Lines changed: 2404 additions & 20 deletions

cmake/libs/libfaiss.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ if(__X86_64)
5050
${UTILS_SRC} $<TARGET_OBJECTS:utils_sse> $<TARGET_OBJECTS:utils_avx>
5151
$<TARGET_OBJECTS:utils_avx512> $<TARGET_OBJECTS:utils_avx512icx>)
5252
target_link_libraries(knowhere_utils PUBLIC glog::glog)
53+
target_link_libraries(knowhere_utils PUBLIC xxHash::xxhash)
5354
endif()
5455

5556
if(__AARCH64)
@@ -99,19 +100,22 @@ if(__AARCH64)
99100
endif()
100101

101102
target_link_libraries(knowhere_utils PUBLIC glog::glog)
103+
target_link_libraries(knowhere_utils PUBLIC xxHash::xxhash)
102104
endif()
103105

104106
if(__RISCV64)
105107
set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc)
106108
add_library(knowhere_utils STATIC ${UTILS_SRC})
107109
target_link_libraries(knowhere_utils PUBLIC glog::glog)
110+
target_link_libraries(knowhere_utils PUBLIC xxHash::xxhash)
108111
endif()
109112

110113
# ToDo: Add distances_vsx.cc for powerpc64 SIMD acceleration
111114
if(__PPC64)
112115
set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc src/simd/distances_powerpc.cc)
113116
add_library(knowhere_utils STATIC ${UTILS_SRC})
114117
target_link_libraries(knowhere_utils PUBLIC glog::glog)
118+
target_link_libraries(knowhere_utils PUBLIC xxHash::xxhash)
115119
endif()
116120

117121

@@ -131,6 +135,9 @@ else()
131135
find_package(BLAS REQUIRED)
132136
endif()
133137

138+
find_package(xxHash REQUIRED)
139+
include_directories(${xxHash_INCLUDE_DIRS})
140+
134141
if(__X86_64)
135142
list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX2_SRCS})
136143

conanfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def requirements(self):
106106
self.requires("folly/2023.10.30.09@milvus/dev")
107107
self.requires("libcurl/8.2.1")
108108
self.requires("simde/0.8.2")
109+
self.requires("xxhash/0.8.3")
109110
if self.settings.os == "Android":
110111
self.requires("openblas/0.3.27")
111112
if not self.options.with_light:
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//
2+
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
3+
// with the License. You may obtain a copy of the License at
4+
//
5+
// http://www.apache.org/licenses/LICENSE-2.0
6+
//
7+
// Unless required by applicable law or agreed to in writing, software distributed under the License
8+
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
9+
// or implied. See the License for the specific language governing permissions and limitations under the License.
10+
#ifndef KNOWHERE_KNOWHERE_H
11+
#define KNOWHERE_KNOWHERE_H
12+
#include <cmath>
13+
#include <cstring>
14+
#include <fstream>
15+
#include <functional>
16+
#include <vector>
17+
18+
#include "io/memory_io.h"
19+
#include "knowhere/bitsetview.h"
20+
#include "knowhere/utils.h"
21+
namespace knowhere {
22+
template <typename T>
23+
class BloomFilter {
24+
public:
25+
explicit BloomFilter(size_t expected_elements, double false_positive_prob)
26+
: n(expected_elements), p(false_positive_prob) {
27+
m = static_cast<size_t>(-(n * log(p)) / (log(2) * log(2)));
28+
k = static_cast<int>(m / n * log(2));
29+
m = std::max<size_t>(m, 1);
30+
k = std::max(k, 1);
31+
bits.resize(m, false);
32+
}
33+
34+
void
35+
add(const T& element) {
36+
size_t glb_hash = hash((const char*)&element, sizeof(element), 0);
37+
for (int i = 0; i < k; ++i) {
38+
size_t pos = (glb_hash + i) % m;
39+
bits[pos] = true;
40+
}
41+
}
42+
43+
bool
44+
contains(const T& element) const {
45+
size_t glb_hash = hash((const char*)&element, sizeof(element), 0);
46+
for (int i = 0; i < k; ++i) {
47+
size_t pos = (glb_hash + i) % m;
48+
if (!bits[pos])
49+
return false;
50+
}
51+
return true;
52+
}
53+
54+
void
55+
save(MemoryIOWriter& writer) const {
56+
writeBinaryPOD(writer, m);
57+
writeBinaryPOD(writer, k);
58+
writeBinaryPOD(writer, n);
59+
writeBinaryPOD(writer, p);
60+
auto bytes_num = (m + 8 - 1) / 8;
61+
std::vector<char> buffer(bytes_num, 0);
62+
for (size_t i = 0; i < m; ++i) {
63+
if (bits[i]) {
64+
buffer[i / 8] |= (1 << (i % 8));
65+
}
66+
}
67+
writer.write(buffer.data(), buffer.size());
68+
}
69+
70+
void
71+
load(MemoryIOReader& reader) {
72+
readBinaryPOD(reader, m);
73+
readBinaryPOD(reader, k);
74+
readBinaryPOD(reader, n);
75+
readBinaryPOD(reader, p);
76+
bits.clear();
77+
bits.resize(m);
78+
auto bytes_num = (m + 8 - 1) / 8;
79+
std::vector<char> buffer(bytes_num);
80+
reader.read(buffer.data(), bytes_num);
81+
for (size_t i = 0; i < m; ++i) {
82+
bool bit = (buffer[i / 8] >> (i % 8)) & 1;
83+
bits.push_back(bit);
84+
}
85+
}
86+
size_t
87+
size() const {
88+
return n;
89+
}
90+
double
91+
false_positive_rate() const {
92+
return p;
93+
}
94+
size_t
95+
memory_usage() const {
96+
return m / 8;
97+
}
98+
99+
private:
100+
static constexpr size_t multiplier = 31;
101+
std::vector<bool> bits;
102+
size_t n = 0;
103+
double p = 0.0;
104+
size_t m = 0;
105+
int k = 0;
106+
107+
size_t
108+
hash(const char* data, size_t length, size_t bucket_i) const {
109+
if (data == nullptr) {
110+
throw std::runtime_error("can't hash null data.");
111+
}
112+
size_t result = 0;
113+
for (size_t i = 0; i < length; ++i) {
114+
result = (result * multiplier) + static_cast<size_t>(data[i]);
115+
}
116+
return (result + bucket_i) % m;
117+
}
118+
};
119+
} // namespace knowhere
120+
#endif

include/knowhere/comp/index_param.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ constexpr const char* INDEX_HNSW_PQ = "HNSW_PQ";
5656
constexpr const char* INDEX_HNSW_PRQ = "HNSW_PRQ";
5757

5858
constexpr const char* INDEX_DISKANN = "DISKANN";
59+
constexpr const char* INDEX_MINHASH_LSH = "MINHASH_LSH";
5960

6061
constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
6162
constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND";
@@ -187,6 +188,17 @@ constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search";
187188
// RaBitQ Params
188189
constexpr const char* RABITQ_QUERY_BITS = "rbq_bits_query";
189190

191+
// minhash meta Params
192+
constexpr const char* MH_ELEMENT_BIT_WIDTH = "mh_element_bit_width";
193+
constexpr const char* MH_LSH_SEARCH_WITH_JACCARD = "mh_search_with_jaccard";
194+
// minhash lsh index params
195+
constexpr const char* MH_LSH_ALIGNED_BLOCK_SIZE = "mh_lsh_aligned_block_size";
196+
constexpr const char* MH_LSH_BAND = "mh_lsh_band";
197+
constexpr const char* MH_LSH_SHARED_BLOOM_FILTER = "mh_lsh_shared_bloom_filter";
198+
constexpr const char* MH_LSH_BLOOM_FALSE_POSITIVE_RPOB = "mh_lsh_bloom_false_positive_prob";
199+
constexpr const char* MH_LSH_HASH_CODE_IN_MEM = "mh_lsh_code_in_mem";
200+
constexpr const char* MH_LSH_REFINE_K = "refine_k";
201+
constexpr const char* MH_LSH_BATCH_SEARCH = "mh_lsh_batch_search";
190202
} // namespace indexparam
191203

192204
using MetricType = std::string;
@@ -197,6 +209,7 @@ constexpr const char* L2 = "L2";
197209
constexpr const char* COSINE = "COSINE";
198210
constexpr const char* HAMMING = "HAMMING";
199211
constexpr const char* JACCARD = "JACCARD";
212+
constexpr const char* MHJACCARD = "MHJACCARD";
200213
constexpr const char* SUBSTRUCTURE = "SUBSTRUCTURE";
201214
constexpr const char* SUPERSTRUCTURE = "SUPERSTRUCTURE";
202215
constexpr const char* BM25 = "BM25";

include/knowhere/config.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,12 @@ class BaseConfig : public Config {
633633
*/
634634
CFG_INT refine_type;
635635
CFG_BOOL refine_with_quant;
636+
/*
637+
* mh_lsh_band is a special parameters of BF search and MinHash index node train.
638+
*/
639+
CFG_INT mh_lsh_band;
640+
CFG_BOOL mh_search_with_jaccard;
641+
CFG_INT mh_element_bit_width;
636642
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
637643
KNOWHERE_CONFIG_DECLARE_FIELD(dim).allow_empty_without_default().description("vector dim").for_train();
638644
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type)
@@ -781,6 +787,21 @@ class BaseConfig : public Config {
781787
.for_search()
782788
.for_range_search()
783789
.for_iterator();
790+
KNOWHERE_CONFIG_DECLARE_FIELD(mh_lsh_band)
791+
.description("param of MinHashLSH")
792+
.set_default(1)
793+
.for_train()
794+
.for_search();
795+
KNOWHERE_CONFIG_DECLARE_FIELD(mh_element_bit_width)
796+
.description("sizeof(hash code), the hash element should be aligned on 8 bits")
797+
.set_default(8)
798+
.set_range(8, 256)
799+
.for_train()
800+
.for_search();
801+
KNOWHERE_CONFIG_DECLARE_FIELD(mh_search_with_jaccard)
802+
.description("return the jaccard distance of minhash vector search or minhashlsh hit flag.")
803+
.set_default(false)
804+
.for_search();
784805
}
785806
};
786807
} // namespace knowhere

include/knowhere/index/index_factory.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ class IndexFactory {
143143
#define KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(name, index_node, features, ...) \
144144
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, int8, (features | knowhere::feature::INT8), ##__VA_ARGS__);
145145

146+
// register vector index supporting binary data types
147+
#define KNOWHERE_MOCK_REGISTER_DENSE_BINARY_ALL_GLOBAL(name, index_node, features, ...) \
148+
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__);
149+
146150
// register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types, but mocked bf16 and fp16
147151
#define KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
148152
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \

include/knowhere/index/index_table.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
9898
// sparse index
9999
{IndexEnum::INDEX_SPARSE_INVERTED_INDEX, VecType::VECTOR_SPARSE_FLOAT},
100100
{IndexEnum::INDEX_SPARSE_WAND, VecType::VECTOR_SPARSE_FLOAT},
101+
// minhash index
102+
{IndexEnum::INDEX_MINHASH_LSH, VecType::VECTOR_BINARY},
101103
};
102104

103105
static std::set<std::string> legal_support_mmap_knowhere_index = {

include/knowhere/utils.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <strings.h>
1515

1616
#include <algorithm>
17+
#include <fstream>
1718
#include <vector>
1819

1920
#include "knowhere/binaryset.h"
@@ -226,6 +227,67 @@ readBinaryPOD(R& in, T& podRef) {
226227
in.read((char*)&podRef, sizeof(T));
227228
}
228229

230+
template <typename T>
231+
inline void
232+
load_vec_meta(const std::string& bin_file, size_t& rows, size_t& dim) {
233+
uint32_t u32_rows, u32_dim;
234+
std::ifstream file(bin_file, std::ios::binary | std::ios::ate);
235+
if (!file.is_open()) {
236+
throw std::runtime_error("fail to open file: " + bin_file);
237+
}
238+
size_t autual_file_size = file.tellg();
239+
file.seekg(0, std::ios::beg);
240+
file.read(reinterpret_cast<char*>(&u32_rows), sizeof(uint32_t));
241+
file.read(reinterpret_cast<char*>(&u32_dim), sizeof(uint32_t));
242+
rows = u32_rows;
243+
dim = u32_dim;
244+
// check data dim and size
245+
size_t expect_file_size = 0;
246+
if constexpr (std::is_same_v<T, bin1>) {
247+
if (dim % 8 != 0) {
248+
throw std::runtime_error("fail to load binary vector base file, dim % 8 != 0 ");
249+
}
250+
expect_file_size = rows * dim / 8 + 2 * sizeof(uint32_t);
251+
} else {
252+
expect_file_size = rows * dim * sizeof(T) + 2 * sizeof(uint32_t);
253+
}
254+
if (autual_file_size != expect_file_size) {
255+
throw std::runtime_error("fail to get raw data meta, file size mismatch of raw data file.");
256+
}
257+
}
258+
259+
template <typename T>
260+
inline void
261+
load_vec_data(const std::string& bin_file, std::unique_ptr<char[]>& data, size_t& npts, size_t& dim) {
262+
std::ifstream file(bin_file, std::ios::binary | std::ios::ate);
263+
if (!file.is_open()) {
264+
throw std::runtime_error("fail to open file: " + bin_file);
265+
}
266+
size_t autual_file_size = file.tellg();
267+
file.seekg(0, std::ios::beg);
268+
uint32_t n, d;
269+
file.read(reinterpret_cast<char*>(&n), sizeof(uint32_t));
270+
file.read(reinterpret_cast<char*>(&d), sizeof(uint32_t));
271+
npts = n;
272+
dim = d;
273+
size_t expect_file_size = 0;
274+
// check data dim and size
275+
if constexpr (std::is_same_v<T, bin1>) {
276+
if (dim % 8 != 0) {
277+
throw std::runtime_error("fail to load binary vector base file, dim % 8 != 0 ");
278+
}
279+
expect_file_size = npts * dim / 8 + 2 * sizeof(uint32_t);
280+
} else {
281+
expect_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
282+
}
283+
if (autual_file_size != expect_file_size) {
284+
throw std::runtime_error("fail to load raw data, file size mismatch of raw data file.");
285+
}
286+
uint64_t total_size = dim * npts / 8;
287+
data = std::make_unique<char[]>(total_size);
288+
file.read(reinterpret_cast<char*>(data.get()), total_size);
289+
}
290+
229291
// taken from
230292
// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
231293
// round up X to the nearest multiple of Y

0 commit comments

Comments
 (0)