diff --git a/include/pisa/accumulator/blocked_accumulator.hpp b/include/pisa/accumulator/blocked_accumulator.hpp new file mode 100644 index 000000000..7d03dcb40 --- /dev/null +++ b/include/pisa/accumulator/blocked_accumulator.hpp @@ -0,0 +1,84 @@ +#pragma once + +namespace pisa { + +template +struct Blocked_Accumulator { + + struct Proxy_Element { + std::ptrdiff_t document; + std::vector &accumulators; + std::vector &accumulators_max; + + Proxy_Element &operator=(float score) { + accumulators[document] = score; + auto &block_max = accumulators_max[document / block_size]; + if (score > block_max) { + block_max = score; + } + return *this; + } + Proxy_Element &operator+=(float delta) { + accumulators[document] += delta; + auto const&score = accumulators[document]; + auto &block_max = accumulators_max[document / block_size]; + if (score > block_max) { + block_max = score; + } + return *this; + } + + operator float() { return accumulators[document]; } + }; + + using reference = Proxy_Element; + + static_assert(block_size > 0, "must be positive"); + + [[nodiscard]] constexpr static auto calc_block_count(std::size_t size) noexcept -> std::size_t { + return (size + block_size - 1) / block_size; + } + + Blocked_Accumulator(std::size_t size) + : m_size(size), + m_block_count(calc_block_count(size)), m_accumulators(size), + m_accumulators_max(m_block_count) {} + + void init() { std::fill(m_accumulators.begin(), m_accumulators.end(), 0.0); } + + [[nodiscard]] auto operator[](std::ptrdiff_t document) -> Proxy_Element + { + return {document, m_accumulators, m_accumulators_max}; + } + + void accumulate(std::ptrdiff_t const document, float score_delta) + { + m_accumulators[document] += score_delta; + auto const &score = m_accumulators[document]; + auto &block_max = m_accumulators_max[document / block_size]; + if (score > block_max) { + block_max = score; + } + } + + void aggregate(topk_queue &topk) { + for (size_t block = 0; block < m_block_count; ++block) { + if (not topk.would_enter(m_accumulators_max[block])) { continue; } + uint32_t doc = block * block_size; + uint32_t end = std::min((block + 1) * block_size, m_accumulators.size()); + for (; doc < end; ++doc) { + topk.insert(m_accumulators[doc], doc); + } + } + } + + [[nodiscard]] auto size() noexcept -> std::size_t { return m_size; } + + private: + std::size_t m_size; + std::size_t m_block_count; + std::vector m_accumulators; + std::vector m_accumulators_max; +}; + +} // pisa diff --git a/include/pisa/query/algorithm/maxscore_taat_query.hpp b/include/pisa/query/algorithm/maxscore_taat_query.hpp new file mode 100644 index 000000000..48548fa61 --- /dev/null +++ b/include/pisa/query/algorithm/maxscore_taat_query.hpp @@ -0,0 +1,145 @@ +#pragma once + +#include "topk_queue.hpp" +#include "util/intrinsics.hpp" + +#include "accumulator/blocked_accumulator.hpp" +#include "accumulator/lazy_accumulator.hpp" +#include "accumulator/simple_accumulator.hpp" + +namespace pisa { + +template +[[nodiscard]] auto max_weights(Index const &index, WandType const &wdata, term_id_vec terms) { + // TODO(michal): parametrize scorer_type; didn't do that because this might mean some more + // complex refactoring I want to avoid for now. + using scorer_type = bm25; + using cursor_type = typename Index::document_enumerator; + using score_function_type = Score_Function; + + auto query_term_freqs = query_freqs(terms); + std::vector max_weights; + max_weights.reserve(query_term_freqs.size()); + + for (auto term : query_term_freqs) { + auto list = index[term.first]; + auto q_weight = scorer_type::query_term_weight(term.second, list.size(), index.num_docs()); + max_weights.push_back(q_weight * wdata.max_term_weight(term.first)); + } + return max_weights; +} + +template +std::vector sort_permutation(Container const &container, Function sort_function) { + std::vector p(container.size()); + std::iota(p.begin(), p.end(), 0); + std::sort(p.begin(), p.end(), [&](std::size_t i, std::size_t j) { + return sort_function(container[i], container[j]); + }); + return p; +} + +template +void apply_permutation(Container &container, const std::vector &p) { + std::vector done(container.size()); + for (std::size_t i = 0; i < container.size(); ++i) { + if (done[i]) { + continue; + } + done[i] = true; + std::size_t prev_j = i; + std::size_t j = p[i]; + while (i != j) { + std::swap(container[prev_j], container[j]); + done[j] = true; + prev_j = j; + j = p[j]; + } + } +} + +template +void sort_many(Container &key_container, Function sort_function, Containers &... containers) { + auto permutation = sort_permutation(key_container, sort_function); + (apply_permutation(containers, permutation), ...); +} + +template +class maxscore_taat_query { + using accumulator_reference = typename Acc::reference; + using score_function_type = Score_Function; + + public: + maxscore_taat_query(Index const &index, WandType const &wdata, uint64_t k) + : m_index(index), m_wdata(wdata), m_k(k), m_topk(k), m_accumulators(index.num_docs()) {} + + uint64_t operator()(term_id_vec terms) { + m_topk.clear(); + auto cws = query::cursors_with_scores(m_index, m_wdata, terms); + auto cursors = cws.first; + auto score_functions = cws.second; + auto m_w = max_weights(m_index, m_wdata, terms); + if (cursors.empty()) { + return 0; + } + sort_many( + m_w, [](auto lhs, auto rhs) { return lhs > rhs; }, cursors, score_functions); + + float nonessential_sum = std::accumulate(m_w.begin(), m_w.end(), 0.0); + m_accumulators.init(); + uint32_t term = 0; + for (; term < cursors.size(); ++term) { + if (not m_topk.would_enter(nonessential_sum)) { + break; + } + m_topk.clear(); + auto cursor = cursors[term]; + auto score = score_functions[term]; + // TODO(antonio): basically here we can do a bit better. + // before scoring a document, we read its accumulator value and check if the sum of + // the accumulator value and the upper bound of the maxscores of the missing terms + // (current included) is greater than the threshold. If it is we score and add it to the + // accumulator, we go to the next document otherwise. + for (; cursor.docid() < m_accumulators.size(); cursor.next()) { + if(m_topk.would_enter(nonessential_sum + m_accumulators[cursor.docid()])) { + m_accumulators.accumulate(cursor.docid(), score(cursor.docid(), cursor.freq())); + m_topk.insert(m_accumulators[cursor.docid()]); + } + } + nonessential_sum -= m_w[term]; + } + + for (; term < cursors.size(); ++term) { + auto cursor = cursors[term]; + auto score = score_functions[term]; + for (; cursor.docid() < m_accumulators.size(); cursor.next()) { + accumulator_reference accumulator = m_accumulators[cursor.docid()]; + if (accumulator > 0) { + accumulator += score(cursor.docid(), cursor.freq()); + } + } + } + + m_topk.clear(); + m_accumulators.aggregate(m_topk); + m_topk.finalize(); + return m_topk.topk().size(); + } + + + std::vector> const &topk() const { return m_topk.topk(); } + + private: + Index const & m_index; + WandType const &m_wdata; + int m_k; + topk_queue m_topk; + Acc m_accumulators; +}; + +template +[[nodiscard]] auto make_maxscore_taat_query(Index const &index, WandType const &wdata, uint64_t k) { + return maxscore_taat_query(index, wdata, k); +} + +}; // namespace pisa diff --git a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp index 5554e4c2c..9935ffe84 100644 --- a/include/pisa/query/queries.hpp +++ b/include/pisa/query/queries.hpp @@ -102,3 +102,4 @@ template #include "algorithm/ranked_or_query.hpp" #include "algorithm/wand_query.hpp" #include "algorithm/ranked_or_taat_query.hpp" +#include "algorithm/maxscore_taat_query.hpp" diff --git a/src/queries.cpp b/src/queries.cpp index 30b234d72..049fc4c41 100644 --- a/src/queries.cpp +++ b/src/queries.cpp @@ -142,6 +142,14 @@ void perftest(const std::string &index_filename, } else if (t == "ranked_or_taat_lazy" && wand_data_filename) { query_fun = pisa::make_ranked_or_taat_query>(index, wdata, k); + } else if (t == "ranked_or_taat_blocked" && wand_data_filename) { + query_fun = + pisa::make_ranked_or_taat_query>(index, wdata, k); + } else if (t == "maxscore_taat" && wand_data_filename) { + query_fun = pisa::make_maxscore_taat_query(index, wdata, k); + } else if (t == "maxscore_taat_blocked" && wand_data_filename) { + query_fun = + pisa::make_maxscore_taat_query>(index, wdata, k); } else { spdlog::error("Unsupported query type: {}", t); break; diff --git a/test/test_ranked_queries.cpp b/test/test_ranked_queries.cpp index b5ac0a5e2..5574a6cae 100644 --- a/test/test_ranked_queries.cpp +++ b/test/test_ranked_queries.cpp @@ -100,6 +100,27 @@ TEST_CASE_METHOD(pisa::test::index_initialization, "block_max_maxscore") test_against_or(bmm_q); } +TEST_CASE_METHOD(pisa::test::index_initialization, "maxscore_taat") +{ + pisa::ranked_or_taat_query ranked_or_taat_q( + index, wdata, 10); + test_against_or(ranked_or_taat_q); +} + +TEST_CASE_METHOD(pisa::test::index_initialization, "ranked_or_taat_lazy") +{ + pisa::ranked_or_taat_query> ranked_or_taat_q( + index, wdata, 10); + test_against_or(ranked_or_taat_q); +} + +TEST_CASE_METHOD(pisa::test::index_initialization, "maxscore_taat_blocked") +{ + pisa::maxscore_taat_query> + taat_q(index, wdata, 10); + test_against_or(taat_q); +} + TEST_CASE_METHOD(pisa::test::index_initialization, "ranked_or_taat") { @@ -108,7 +129,14 @@ TEST_CASE_METHOD(pisa::test::index_initialization, "ranked_or_taat") test_against_or(ranked_or_taat_q); } -TEST_CASE_METHOD(pisa::test::index_initialization, "ranked_or_taat_lazy") +TEST_CASE_METHOD(pisa::test::index_initialization, "ranked_or_taat_blocked") +{ + pisa::ranked_or_taat_query> + ranked_or_taat_q(index, wdata, 10); + test_against_or(ranked_or_taat_q); +} + +TEST_CASE_METHOD(pisa::test::index_initialization, "ranked_or_taat_query_lazy") { pisa::ranked_or_taat_query> ranked_or_taat_q( index, wdata, 10);