Skip to content

Commit 3b3f6a5

Browse files
authored
Optimize BM25 scoring in DAAT MaxScore (#1629)
Signed-off-by: lyang24 <lanqingy93@gmail.com>
1 parent 0e53e57 commit 3b3f6a5

2 files changed

Lines changed: 48 additions & 4 deletions

File tree

src/index/sparse/scorer.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <cstdint>
44
#include <functional>
5+
#include <vector>
56

67
namespace knowhere::sparse::inverted {
78

@@ -86,8 +87,9 @@ struct BM25IndexScorer : public IndexScorer {
8687
// In senario of BM25, qval is IDF value, rval is TF value
8788
[[nodiscard]] DimScorer
8889
dim_scorer(float qval) const override {
90+
const float qval_p1 = qval * p1_;
8991
return
90-
[&, qval](uint32_t rid, uint32_t rval) { return qval * p1_ * rval / (rval + p2_ + p3_ * row_sums_[rid]); };
92+
[&, qval_p1](uint32_t rid, uint32_t rval) { return qval_p1 * rval / (rval + p2_ + p3_ * row_sums_[rid]); };
9193
}
9294

9395
[[nodiscard]] float
@@ -100,6 +102,21 @@ struct BM25IndexScorer : public IndexScorer {
100102
return row_sums_;
101103
}
102104

105+
[[nodiscard]] float
106+
p1() const {
107+
return p1_;
108+
}
109+
110+
[[nodiscard]] float
111+
p2() const {
112+
return p2_;
113+
}
114+
115+
[[nodiscard]] float
116+
p3() const {
117+
return p3_;
118+
}
119+
103120
protected:
104121
const float p1_;
105122
const float p2_;

src/index/sparse/searcher/daat_maxscore.h

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
#include <algorithm>
11+
#include <cassert>
1112
#include <memory>
1213
#include <numeric>
1314
#include <utility>
@@ -26,6 +27,7 @@ class DaatMaxScoreSearcher : public RankedSearcher {
2627
typename IndexType::posting_list_iterator index_cursor;
2728
DimScorer scorer;
2829
float max_score;
30+
float qval_p1;
2931

3032
[[nodiscard]] uint32_t
3133
vec_id() const noexcept {
@@ -61,6 +63,12 @@ class DaatMaxScoreSearcher : public RankedSearcher {
6163
max_vec_id_(max_vec_id),
6264
row_sums_(index.get_row_sums()),
6365
scorer_type_(search_scorer->config().scorer_type) {
66+
if (scorer_type_ == IndexScorerType::BM25) {
67+
const auto* bm25_scorer = dynamic_cast<const BM25IndexScorer*>(search_scorer.get());
68+
assert(bm25_scorer != nullptr);
69+
bm25_p2_ = bm25_scorer->p2();
70+
bm25_p3_ = bm25_scorer->p3();
71+
}
6472
}
6573

6674
[[nodiscard]] auto
@@ -136,16 +144,27 @@ class DaatMaxScoreSearcher : public RankedSearcher {
136144

137145
current_score = 0;
138146
current_vec_id = std::exchange(next_vec_id, max_vec_id);
147+
float doc_norm = 0.0f;
139148

140149
if constexpr (ScorerType == IndexScorerType::BM25) {
141150
// Prefetch row_sums_ for next iterations that will be used by the BM25 scorer
142151
// Experiments show this prefetch pattern is optimal vs only prefetching next_vec_id
143152
__builtin_prefetch(&row_sums_[current_vec_id], 0, 3);
153+
doc_norm = bm25_p2_ + bm25_p3_ * row_sums_[current_vec_id];
144154
}
145155

156+
auto score_term = [&](auto& cursor) -> float {
157+
if constexpr (ScorerType == IndexScorerType::BM25) {
158+
const float tf = static_cast<float>(cursor.index_cursor.val());
159+
return cursor.qval_p1 * tf / (tf + doc_norm);
160+
} else {
161+
return cursor.score();
162+
}
163+
};
164+
146165
std::for_each(cursors.begin(), first_lookup, [&](auto& cursor) {
147166
if (cursor.vec_id() == current_vec_id) {
148-
current_score += cursor.score();
167+
current_score += score_term(cursor);
149168
cursor.next();
150169
if constexpr (ScorerType == IndexScorerType::BM25) {
151170
// Prefetch row_sums_ for next iterations that will be used by the BM25 scorer
@@ -168,7 +187,7 @@ class DaatMaxScoreSearcher : public RankedSearcher {
168187
}
169188
cursor.next_geq(current_vec_id);
170189
if (cursor.vec_id() == current_vec_id) {
171-
current_score += cursor.score();
190+
current_score += score_term(cursor);
172191
}
173192
}
174193
}
@@ -200,9 +219,15 @@ class DaatMaxScoreSearcher : public RankedSearcher {
200219
float dim_max_score_ratio) {
201220
std::vector<Cursor> cursors;
202221
cursors.reserve(query.size());
222+
const BM25IndexScorer* bm25_scorer = nullptr;
223+
if (index_scorer->config().scorer_type == IndexScorerType::BM25) {
224+
bm25_scorer = dynamic_cast<const BM25IndexScorer*>(index_scorer.get());
225+
assert(bm25_scorer != nullptr);
226+
}
203227
for (const auto& [dim_id, dim_val] : query) {
204228
cursors.push_back(Cursor{index.get_dim_plist_cursor(dim_id, bitset), index_scorer->dim_scorer(dim_val),
205-
dim_max_score_ratio * index.get_dim_max_score(dim_id, dim_val)});
229+
dim_max_score_ratio * index.get_dim_max_score(dim_id, dim_val),
230+
bm25_scorer != nullptr ? dim_val * bm25_scorer->p1() : 0.0f});
206231
}
207232
return cursors;
208233
}
@@ -212,6 +237,8 @@ class DaatMaxScoreSearcher : public RankedSearcher {
212237
// row_sums_ is only used for BM25 scorer
213238
const std::vector<float>& row_sums_;
214239
IndexScorerType scorer_type_;
240+
float bm25_p2_{0.0f};
241+
float bm25_p3_{0.0f};
215242
};
216243

217244
} // namespace knowhere::sparse::inverted

0 commit comments

Comments
 (0)