88#pragma once
99
1010#include < algorithm>
11+ #include < cassert>
1112#include < memory>
1213#include < numeric>
1314#include < utility>
@@ -26,6 +27,7 @@ class DaatMaxScoreSearcher : public RankedSearcher {
2627 typename IndexType::posting_list_iterator index_cursor;
2728 DimScorer scorer;
2829 float max_score;
30+ float qval_p1;
2931
3032 [[nodiscard]] uint32_t
3133 vec_id () const noexcept {
@@ -61,6 +63,12 @@ class DaatMaxScoreSearcher : public RankedSearcher {
6163 max_vec_id_(max_vec_id),
6264 row_sums_(index.get_row_sums()),
6365 scorer_type_(search_scorer->config ().scorer_type) {
66+ if (scorer_type_ == IndexScorerType::BM25) {
67+ const auto * bm25_scorer = dynamic_cast <const BM25IndexScorer*>(search_scorer.get ());
68+ assert (bm25_scorer != nullptr );
69+ bm25_p2_ = bm25_scorer->p2 ();
70+ bm25_p3_ = bm25_scorer->p3 ();
71+ }
6472 }
6573
6674 [[nodiscard]] auto
@@ -136,16 +144,27 @@ class DaatMaxScoreSearcher : public RankedSearcher {
136144
137145 current_score = 0 ;
138146 current_vec_id = std::exchange (next_vec_id, max_vec_id);
147+ float doc_norm = 0 .0f ;
139148
140149 if constexpr (ScorerType == IndexScorerType::BM25) {
141150 // Prefetch row_sums_ for next iterations that will be used by the BM25 scorer
142151 // Experiments show this prefetch pattern is optimal vs only prefetching next_vec_id
143152 __builtin_prefetch (&row_sums_[current_vec_id], 0 , 3 );
153+ doc_norm = bm25_p2_ + bm25_p3_ * row_sums_[current_vec_id];
144154 }
145155
156+ auto score_term = [&](auto & cursor) -> float {
157+ if constexpr (ScorerType == IndexScorerType::BM25) {
158+ const float tf = static_cast <float >(cursor.index_cursor .val ());
159+ return cursor.qval_p1 * tf / (tf + doc_norm);
160+ } else {
161+ return cursor.score ();
162+ }
163+ };
164+
146165 std::for_each (cursors.begin (), first_lookup, [&](auto & cursor) {
147166 if (cursor.vec_id () == current_vec_id) {
148- current_score += cursor. score ( );
167+ current_score += score_term (cursor );
149168 cursor.next ();
150169 if constexpr (ScorerType == IndexScorerType::BM25) {
151170 // Prefetch row_sums_ for next iterations that will be used by the BM25 scorer
@@ -168,7 +187,7 @@ class DaatMaxScoreSearcher : public RankedSearcher {
168187 }
169188 cursor.next_geq (current_vec_id);
170189 if (cursor.vec_id () == current_vec_id) {
171- current_score += cursor. score ( );
190+ current_score += score_term (cursor );
172191 }
173192 }
174193 }
@@ -200,9 +219,15 @@ class DaatMaxScoreSearcher : public RankedSearcher {
200219 float dim_max_score_ratio) {
201220 std::vector<Cursor> cursors;
202221 cursors.reserve (query.size ());
222+ const BM25IndexScorer* bm25_scorer = nullptr ;
223+ if (index_scorer->config ().scorer_type == IndexScorerType::BM25) {
224+ bm25_scorer = dynamic_cast <const BM25IndexScorer*>(index_scorer.get ());
225+ assert (bm25_scorer != nullptr );
226+ }
203227 for (const auto & [dim_id, dim_val] : query) {
204228 cursors.push_back (Cursor{index.get_dim_plist_cursor (dim_id, bitset), index_scorer->dim_scorer (dim_val),
205- dim_max_score_ratio * index.get_dim_max_score (dim_id, dim_val)});
229+ dim_max_score_ratio * index.get_dim_max_score (dim_id, dim_val),
230+ bm25_scorer != nullptr ? dim_val * bm25_scorer->p1 () : 0 .0f });
206231 }
207232 return cursors;
208233 }
@@ -212,6 +237,8 @@ class DaatMaxScoreSearcher : public RankedSearcher {
212237 // row_sums_ is only used for BM25 scorer
213238 const std::vector<float >& row_sums_;
214239 IndexScorerType scorer_type_;
240+ float bm25_p2_{0 .0f };
241+ float bm25_p3_{0 .0f };
215242};
216243
217244} // namespace knowhere::sparse::inverted
0 commit comments