diff --git a/CMakeLists.txt b/CMakeLists.txt index 0c0e97139..5d84a7e2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,6 @@ list(APPEND LCOV_REMOVE_PATTERNS "'${PROJECT_SOURCE_DIR}/external/*'") if (UNIX) - # For hardware popcount and other special instructions set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") @@ -62,8 +61,6 @@ endif() set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) -link_libraries(Threads::Threads) - include_directories(include) add_library(pisa INTERFACE) diff --git a/include/pisa/accumulator/simple_accumulator.hpp b/include/pisa/accumulator/simple_accumulator.hpp new file mode 100644 index 000000000..07d34f5ba --- /dev/null +++ b/include/pisa/accumulator/simple_accumulator.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <vector> + +#include "topk_queue.hpp" + +namespace pisa { + +/// Score accumulator backed by a vector and indexed by document ID. +struct Simple_Accumulator : public std::vector { + /// Creates an accumulator with `size` slots. + explicit Simple_Accumulator(std::ptrdiff_t size) : std::vector(size) {} + /// Resets every slot to zero. + void init() { std::fill(begin(), end(), 0.0); } + /// Adds `score` to the slot of document `doc`. + void accumulate(uint32_t doc, float score) { operator[](doc) += score; } + /// Offers every slot to `topk` as (score, docid), where docid is the position of the slot. + void aggregate(topk_queue &topk) { + uint64_t docid = 0u; + std::for_each(begin(), end(), [&](auto score) { topk.insert(score, docid++); }); + } +}; + +} // namespace pisa diff --git a/include/pisa/block_posting_list.hpp b/include/pisa/block_posting_list.hpp index ac29c70f3..f2cd24c81 100644 --- a/include/pisa/block_posting_list.hpp +++ b/include/pisa/block_posting_list.hpp @@ -82,6 +82,7 @@ namespace pisa { class document_enumerator { public: + document_enumerator(uint8_t const* data, uint64_t universe, size_t term_id = 0) : m_n(0) // just to silence warnings diff --git a/include/pisa/freq_index.hpp b/include/pisa/freq_index.hpp index 3aaf0bed6..82caaffa3 100644 --- a/include/pisa/freq_index.hpp +++ b/include/pisa/freq_index.hpp @@ -76,6 +76,7 @@ namespace pisa { class document_enumerator { public: + void reset() { m_cur_pos = 0; diff --git a/include/pisa/query/algorithm/and_query.hpp b/include/pisa/query/algorithm/and_query.hpp 
index f9d36f865..3601b246b 100644 --- a/include/pisa/query/algorithm/and_query.hpp +++ b/include/pisa/query/algorithm/and_query.hpp @@ -2,11 +2,12 @@ namespace pisa { -template +template struct and_query { - template - uint64_t operator()(Index const &index, term_id_vec terms) const { + and_query(Index const &index) : m_index(index) {} + + uint64_t operator()(term_id_vec terms) const { if (terms.empty()) return 0; remove_duplicate_terms(terms); @@ -16,7 +17,7 @@ struct and_query { enums.reserve(terms.size()); for (auto term : terms) { - enums.push_back(index[term]); + enums.push_back(m_index[term]); } // sort by increasing frequency @@ -27,7 +28,7 @@ struct and_query { uint64_t results = 0; uint64_t candidate = enums[0].docid(); size_t i = 1; - while (candidate < index.num_docs()) { + while (candidate < m_index.num_docs()) { for (; i < enums.size(); ++i) { enums[i].next_geq(candidate); if (enums[i].docid() != candidate) { @@ -52,6 +53,9 @@ struct and_query { } return results; } + + private: + Index const &m_index; }; } // namespace pisa \ No newline at end of file diff --git a/include/pisa/query/algorithm/block_max_maxscore_query.hpp b/include/pisa/query/algorithm/block_max_maxscore_query.hpp index 5ef880321..a0a95dee1 100644 --- a/include/pisa/query/algorithm/block_max_maxscore_query.hpp +++ b/include/pisa/query/algorithm/block_max_maxscore_query.hpp @@ -2,22 +2,22 @@ namespace pisa { -template +template struct block_max_maxscore_query { typedef bm25 scorer_type; - block_max_maxscore_query(WandType const &wdata, uint64_t k) : m_wdata(&wdata), m_topk(k) {} + block_max_maxscore_query(Index const &index, WandType const &wdata, uint64_t k) + : m_index(index), m_wdata(&wdata), m_topk(k) {} - template - uint64_t operator()(Index const &index, term_id_vec const &terms) { + uint64_t operator()(term_id_vec const &terms) { m_topk.clear(); if (terms.empty()) return 0; auto query_term_freqs = query_freqs(terms); - uint64_t num_docs = index.num_docs(); + uint64_t num_docs = 
m_index.num_docs(); typedef typename Index::document_enumerator enum_type; typedef typename WandType::wand_data_enumerator wdata_enum; @@ -32,7 +32,7 @@ struct block_max_maxscore_query { enums.reserve(query_term_freqs.size()); for (auto term : query_term_freqs) { - auto list = index[term.first]; + auto list = m_index[term.first]; auto w_enum = m_wdata->getenum(term.first); auto q_weight = scorer_type::query_term_weight(term.second, list.size(), num_docs); auto max_weight = q_weight * m_wdata->max_term_weight(term.first); @@ -66,10 +66,10 @@ struct block_max_maxscore_query { }) ->docs_enum.docid(); - while (non_essential_lists < ordered_enums.size() && cur_doc < index.num_docs()) { + while (non_essential_lists < ordered_enums.size() && cur_doc < m_index.num_docs()) { float score = 0; float norm_len = m_wdata->norm_len(cur_doc); - uint64_t next_doc = index.num_docs(); + uint64_t next_doc = m_index.num_docs(); for (size_t i = non_essential_lists; i < ordered_enums.size(); ++i) { if (ordered_enums[i]->docs_enum.docid() == cur_doc) { score += @@ -129,6 +129,7 @@ struct block_max_maxscore_query { std::vector> const &topk() const { return m_topk.topk(); } private: + Index const & m_index; WandType const *m_wdata; topk_queue m_topk; }; diff --git a/include/pisa/query/algorithm/block_max_wand_query.hpp b/include/pisa/query/algorithm/block_max_wand_query.hpp index 1bbd0c396..138083839 100644 --- a/include/pisa/query/algorithm/block_max_wand_query.hpp +++ b/include/pisa/query/algorithm/block_max_wand_query.hpp @@ -2,20 +2,20 @@ namespace pisa { -template +template struct block_max_wand_query { typedef bm25 scorer_type; - block_max_wand_query(WandType const &wdata, uint64_t k) : m_wdata(&wdata), m_topk(k) {} + block_max_wand_query(Index const &index, WandType const &wdata, uint64_t k) + : m_index(index), m_wdata(&wdata), m_topk(k) {} - template - uint64_t operator()(Index const &index, term_id_vec const &terms) { + uint64_t operator()(term_id_vec const &terms) { 
m_topk.clear(); if (terms.empty()) return 0; auto query_term_freqs = query_freqs(terms); - uint64_t num_docs = index.num_docs(); + uint64_t num_docs = m_index.num_docs(); typedef typename Index::document_enumerator enum_type; typedef typename WandType::wand_data_enumerator wdata_enum; @@ -30,7 +30,7 @@ struct block_max_wand_query { enums.reserve(query_term_freqs.size()); for (auto term : query_term_freqs) { - auto list = index[term.first]; + auto list = m_index[term.first]; auto w_enum = m_wdata->getenum(term.first); auto q_weight = scorer_type::query_term_weight(term.second, list.size(), num_docs); @@ -204,6 +204,7 @@ struct block_max_wand_query { topk_queue const &get_topk() const { return m_topk; } private: + Index const & m_index; WandType const *m_wdata; topk_queue m_topk; }; diff --git a/include/pisa/query/algorithm/maxscore_query.hpp b/include/pisa/query/algorithm/maxscore_query.hpp index 37f45c999..c7674fb1d 100644 --- a/include/pisa/query/algorithm/maxscore_query.hpp +++ b/include/pisa/query/algorithm/maxscore_query.hpp @@ -2,22 +2,21 @@ namespace pisa { -template +template struct maxscore_query { typedef bm25 scorer_type; - maxscore_query(WandType const &wdata, uint64_t k) : m_wdata(&wdata), m_topk(k) {} + maxscore_query(Index const &index, WandType const &wdata, uint64_t k) : m_index(index), m_wdata(&wdata), m_topk(k) {} - template - uint64_t operator()(Index const &index, term_id_vec const &terms) { + uint64_t operator()(term_id_vec const &terms) { m_topk.clear(); if (terms.empty()) return 0; auto query_term_freqs = query_freqs(terms); - uint64_t num_docs = index.num_docs(); + uint64_t num_docs = m_index.num_docs(); typedef typename Index::document_enumerator enum_type; struct scored_enum { enum_type docs_enum; @@ -29,7 +28,7 @@ struct maxscore_query { enums.reserve(query_term_freqs.size()); for (auto term : query_term_freqs) { - auto list = index[term.first]; + auto list = m_index[term.first]; auto q_weight = scorer_type::query_term_weight(term.second, 
list.size(), num_docs); auto max_weight = q_weight * m_wdata->max_term_weight(term.first); enums.push_back(scored_enum{std::move(list), q_weight, max_weight}); @@ -62,10 +61,10 @@ struct maxscore_query { }) ->docs_enum.docid(); - while (non_essential_lists < ordered_enums.size() && cur_doc < index.num_docs()) { + while (non_essential_lists < ordered_enums.size() && cur_doc < m_index.num_docs()) { float score = 0; float norm_len = m_wdata->norm_len(cur_doc); - uint64_t next_doc = index.num_docs(); + uint64_t next_doc = m_index.num_docs(); for (size_t i = non_essential_lists; i < ordered_enums.size(); ++i) { if (ordered_enums[i]->docs_enum.docid() == cur_doc) { score += @@ -109,6 +108,7 @@ struct maxscore_query { std::vector> const &topk() const { return m_topk.topk(); } private: + Index const & m_index; WandType const *m_wdata; topk_queue m_topk; }; diff --git a/include/pisa/query/algorithm/or_query.hpp b/include/pisa/query/algorithm/or_query.hpp index 452bed52e..55fba0f64 100644 --- a/include/pisa/query/algorithm/or_query.hpp +++ b/include/pisa/query/algorithm/or_query.hpp @@ -2,11 +2,12 @@ namespace pisa { -template +template struct or_query { - template - uint64_t operator()(Index const &index, term_id_vec terms) const { + or_query(Index const &index) : m_index(index) {} + + uint64_t operator()(term_id_vec terms) const { if (terms.empty()) return 0; remove_duplicate_terms(terms); @@ -16,7 +17,7 @@ struct or_query { enums.reserve(terms.size()); for (auto term : terms) { - enums.push_back(index[term]); + enums.push_back(m_index[term]); } uint64_t results = 0; @@ -27,9 +28,9 @@ struct or_query { }) ->docid(); - while (cur_doc < index.num_docs()) { + while (cur_doc < m_index.num_docs()) { results += 1; - uint64_t next_doc = index.num_docs(); + uint64_t next_doc = m_index.num_docs(); for (size_t i = 0; i < enums.size(); ++i) { if (enums[i].docid() == cur_doc) { if (with_freqs) { @@ -47,6 +48,9 @@ struct or_query { return results; } + + private: + Index const &m_index; 
}; } // namespace pisa \ No newline at end of file diff --git a/include/pisa/query/algorithm/ranked_and_query.hpp b/include/pisa/query/algorithm/ranked_and_query.hpp index 3d330cc15..ef9e7da68 100644 --- a/include/pisa/query/algorithm/ranked_and_query.hpp +++ b/include/pisa/query/algorithm/ranked_and_query.hpp @@ -2,15 +2,15 @@ namespace pisa { -template +template struct ranked_and_query { typedef bm25 scorer_type; - ranked_and_query(WandType const &wdata, uint64_t k) : m_wdata(&wdata), m_topk(k) {} + ranked_and_query(Index const &index, WandType const &wdata, uint64_t k) + : m_index(index), m_wdata(&wdata), m_topk(k) {} - template - uint64_t operator()(Index const &index, term_id_vec terms) { + uint64_t operator()(term_id_vec terms) { size_t results = 0; m_topk.clear(); if (terms.empty()) @@ -18,7 +18,7 @@ struct ranked_and_query { auto query_term_freqs = query_freqs(terms); - uint64_t num_docs = index.num_docs(); + uint64_t num_docs = m_index.num_docs(); typedef typename Index::document_enumerator enum_type; struct scored_enum { enum_type docs_enum; @@ -29,7 +29,7 @@ struct ranked_and_query { enums.reserve(query_term_freqs.size()); for (auto term : query_term_freqs) { - auto list = index[term.first]; + auto list = m_index[term.first]; auto q_weight = scorer_type::query_term_weight(term.second, list.size(), num_docs); enums.push_back(scored_enum{std::move(list), q_weight}); } @@ -41,7 +41,7 @@ struct ranked_and_query { uint64_t candidate = enums[0].docs_enum.docid(); size_t i = 1; - while (candidate < index.num_docs()) { + while (candidate < m_index.num_docs()) { for (; i < enums.size(); ++i) { enums[i].docs_enum.next_geq(candidate); if (enums[i].docs_enum.docid() != candidate) { @@ -80,6 +80,7 @@ struct ranked_and_query { topk_queue &get_topk() { return m_topk; } private: + Index const & m_index; WandType const *m_wdata; topk_queue m_topk; }; diff --git a/include/pisa/query/algorithm/ranked_or_query.hpp b/include/pisa/query/algorithm/ranked_or_query.hpp index 
a5ab88216..dbf9fb06e 100644 --- a/include/pisa/query/algorithm/ranked_or_query.hpp +++ b/include/pisa/query/algorithm/ranked_or_query.hpp @@ -2,22 +2,22 @@ namespace pisa { -template +template struct ranked_or_query { typedef bm25 scorer_type; - ranked_or_query(WandType const &wdata, uint64_t k) : m_wdata(&wdata), m_topk(k) {} + ranked_or_query(Index const &index, WandType const &wdata, uint64_t k) + : m_index(index), m_wdata(&wdata), m_topk(k) {} - template - uint64_t operator()(Index const &index, term_id_vec terms) { + uint64_t operator()(term_id_vec terms) { m_topk.clear(); if (terms.empty()) return 0; auto query_term_freqs = query_freqs(terms); - uint64_t num_docs = index.num_docs(); + uint64_t num_docs = m_index.num_docs(); typedef typename Index::document_enumerator enum_type; struct scored_enum { enum_type docs_enum; @@ -28,7 +28,7 @@ struct ranked_or_query { enums.reserve(query_term_freqs.size()); for (auto term : query_term_freqs) { - auto list = index[term.first]; + auto list = m_index[term.first]; auto q_weight = scorer_type::query_term_weight(term.second, list.size(), num_docs); enums.push_back(scored_enum{std::move(list), q_weight}); } @@ -41,10 +41,10 @@ struct ranked_or_query { }) ->docs_enum.docid(); - while (cur_doc < index.num_docs()) { + while (cur_doc < m_index.num_docs()) { float score = 0; float norm_len = m_wdata->norm_len(cur_doc); - uint64_t next_doc = index.num_docs(); + uint64_t next_doc = m_index.num_docs(); for (size_t i = 0; i < enums.size(); ++i) { if (enums[i].docs_enum.docid() == cur_doc) { score += enums[i].q_weight * @@ -67,6 +67,7 @@ struct ranked_or_query { std::vector> const &topk() const { return m_topk.topk(); } private: + Index const & m_index; WandType const *m_wdata; topk_queue m_topk; }; diff --git a/include/pisa/query/algorithm/ranked_or_taat_query.hpp b/include/pisa/query/algorithm/ranked_or_taat_query.hpp new file mode 100644 index 000000000..1543d63ac --- /dev/null +++ 
b/include/pisa/query/algorithm/ranked_or_taat_query.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include "util/intrinsics.hpp" +#include "topk_queue.hpp" + +#include "accumulator/simple_accumulator.hpp" + +namespace pisa { + +/// Term-at-a-time ranked OR: adds the score of every posting of every query term to a +/// per-document accumulator, then aggregates the accumulators into the top-k queue. +/// Keeps references to `index` and `wdata`; both must outlive the query object. +template +class ranked_or_taat_query { + using score_function_type = Score_Function; + + public: + ranked_or_taat_query(Index const &index, WandType const &wdata, uint64_t k) + : m_index(index), m_wdata(wdata), m_topk(k), m_accumulators(index.num_docs()) {} + + /// Processes `terms` and returns the size of the resulting top-k list. + uint64_t operator()(term_id_vec terms) { + auto [cursors, score_functions] = query::cursors_with_scores(m_index, m_wdata, std::move(terms)); + m_topk.clear(); + if (cursors.empty()) { + return 0; + } + m_accumulators.init(); + for (uint32_t term = 0; term < cursors.size(); ++term) { + // Traverse the cursor in place: `cursors` is local to this call, so no copy is needed. + auto &cursor = cursors[term]; + const auto score = score_functions[term]; + for (; cursor.docid() < m_accumulators.size(); cursor.next()) { + m_accumulators.accumulate(cursor.docid(), score(cursor.docid(), cursor.freq())); + } + } + m_accumulators.aggregate(m_topk); + m_topk.finalize(); + return m_topk.topk().size(); + } + + /// Top-k list held by the queue after the most recent call to operator(). + std::vector> const &topk() const { return m_topk.topk(); } + + private: + Index const & m_index; + WandType const & m_wdata; + topk_queue m_topk; + Acc m_accumulators; +}; + +/// Builds a ranked_or_taat_query over `index` and `wdata` with a top-k queue constructed from `k`. +template +[[nodiscard]] auto make_ranked_or_taat_query(Index const & index, + WandType const &wdata, + uint64_t k) { + return ranked_or_taat_query(index, wdata, k); +} + +} // namespace pisa diff --git a/include/pisa/query/algorithm/wand_query.hpp b/include/pisa/query/algorithm/wand_query.hpp index 2194cdde2..00839493e 100644 --- a/include/pisa/query/algorithm/wand_query.hpp +++ b/include/pisa/query/algorithm/wand_query.hpp @@ -2,22 +2,22 @@ namespace pisa { -template +template struct wand_query { typedef bm25 scorer_type; - wand_query(WandType const &wdata, uint64_t k) : m_wdata(&wdata), m_topk(k) {} + wand_query(Index const &index, WandType const &wdata, uint64_t k) + : m_index(index), m_wdata(&wdata), m_topk(k) {} - template - 
uint64_t operator()(Index const &index, term_id_vec const &terms) { + uint64_t operator()(term_id_vec const &terms) { m_topk.clear(); if (terms.empty()) return 0; auto query_term_freqs = query_freqs(terms); - uint64_t num_docs = index.num_docs(); + uint64_t num_docs = m_index.num_docs(); typedef typename Index::document_enumerator enum_type; struct scored_enum { enum_type docs_enum; @@ -29,7 +29,7 @@ struct wand_query { enums.reserve(query_term_freqs.size()); for (auto term : query_term_freqs) { - auto list = index[term.first]; + auto list = m_index[term.first]; auto q_weight = scorer_type::query_term_weight(term.second, list.size(), num_docs); auto max_weight = q_weight * m_wdata->max_term_weight(term.first); @@ -114,6 +114,7 @@ struct wand_query { std::vector> const &topk() const { return m_topk.topk(); } private: + Index const & m_index; WandType const *m_wdata; topk_queue m_topk; }; diff --git a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp index 6da30f2ab..5554e4c2c 100644 --- a/include/pisa/query/queries.hpp +++ b/include/pisa/query/queries.hpp @@ -51,6 +51,49 @@ term_freq_vec query_freqs(term_id_vec terms) { return query_term_freqs; } +/// Scores a posting as query_weight * Scorer::doc_term_weight(freq, wdata.norm_len(doc)). +template +struct Score_Function { + float query_weight; + std::reference_wrapper wdata; + + [[nodiscard]] auto operator()(uint32_t doc, uint32_t freq) const -> float { + return query_weight * Scorer::doc_term_weight(freq, wdata.get().norm_len(doc)); + } +}; + +// TODO: These are functions common to query processing in general. +// They should be moved out of this file. +namespace query { + +/// Returns a pair (cursors, score functions) with one entry per element of query_freqs(terms), +/// in the same order. The score functions refer to `wdata`, which must outlive them. +template +[[nodiscard]] auto cursors_with_scores(Index const& index, WandType const &wdata, term_id_vec terms) +{ + // TODO(michal): parametrize scorer_type; didn't do that because this might mean some more + // complex refactoring I want to avoid for now. 
+ using scorer_type = bm25; + using cursor_type = typename Index::document_enumerator; + using score_function_type = Score_Function; + + auto query_term_freqs = query_freqs(std::move(terms)); + std::vector cursors; + std::vector score_functions; + cursors.reserve(query_term_freqs.size()); + score_functions.reserve(query_term_freqs.size()); + const uint64_t num_docs = index.num_docs(); + + for (auto term : query_term_freqs) { + auto list = index[term.first]; + auto q_weight = scorer_type::query_term_weight(term.second, list.size(), num_docs); + cursors.push_back(std::move(list)); + score_functions.push_back({q_weight, std::cref(wdata)}); + } + return std::make_pair(std::move(cursors), std::move(score_functions)); +} + +} // namespace query } // namespace pisa #include "algorithm/and_query.hpp" @@ -60,4 +103,5 @@ term_freq_vec query_freqs(term_id_vec terms) { #include "algorithm/or_query.hpp" #include "algorithm/ranked_and_query.hpp" #include "algorithm/ranked_or_query.hpp" -#include "algorithm/wand_query.hpp" \ No newline at end of file +#include "algorithm/wand_query.hpp" +#include "algorithm/ranked_or_taat_query.hpp" diff --git a/src/profile_queries.cpp b/src/profile_queries.cpp index 92a8dcba9..bd5e64ad5 100644 --- a/src/profile_queries.cpp +++ b/src/profile_queries.cpp @@ -15,9 +15,8 @@ #include "query/queries.hpp" #include "util/util.hpp" -template -void op_profile(IndexType const& index, - QueryOperator const& query_op, +template +void op_profile(QueryOperator const& query_op, std::vector const& queries) { using namespace pisa; @@ -35,7 +34,7 @@ void op_profile(IndexType const& index, spdlog::info("{} queries processed", i); } - query_op_copy(index, queries[i]); + query_op_copy(queries[i]); } }); } @@ -88,13 +87,13 @@ void profile(const std::string index_filename, for (auto const& t: query_types) { spdlog::info("Query type: {}", t); if (t == "and") { - op_profile(index, and_query(), queries); + op_profile(and_query::type, false>(index), queries); } else if (t == "ranked_and" && wand_data_filename) 
{ - op_profile(index, ranked_and_query(wdata, 10), queries); + op_profile(ranked_and_query::type, WandType>(index, wdata, 10), queries); } else if (t == "wand" && wand_data_filename) { - op_profile(index, wand_query(wdata, 10), queries); + op_profile(wand_query::type, WandType>(index, wdata, 10), queries); } else if (t == "maxscore" && wand_data_filename) { - op_profile(index, maxscore_query(wdata, 10), queries); + op_profile(maxscore_query::type, WandType>(index, wdata, 10), queries); } else { spdlog::error("Unsupported query type: {}", t); } diff --git a/src/queries.cpp b/src/queries.cpp index 8ddeda93c..827683171 100644 --- a/src/queries.cpp +++ b/src/queries.cpp @@ -119,33 +119,25 @@ void perftest(const std::string &index_filename, spdlog::info("Query type: {}", t); std::function query_fun; if (t == "and") { - query_fun = [&](term_id_vec query) { return and_query()(index, query); }; + query_fun = and_query(index); } else if (t == "and_freq") { - query_fun = [&](term_id_vec query) { return and_query()(index, query); }; + query_fun = and_query(index); } else if (t == "or") { - query_fun = [&](term_id_vec query) { return or_query()(index, query); }; + query_fun = or_query(index); } else if (t == "or_freq") { - query_fun = [&](term_id_vec query) { return or_query()(index, query); }; + query_fun = or_query(index); } else if (t == "wand" && wand_data_filename) { - query_fun = [&](term_id_vec query) { - return wand_query(wdata, k)(index, query); - }; + query_fun = wand_query(index, wdata, k); } else if (t == "block_max_wand" && wand_data_filename) { - query_fun = [&](term_id_vec query) { - return block_max_wand_query(wdata, k)(index, query); - }; + query_fun =block_max_wand_query(index, wdata, k); } else if (t == "block_max_maxscore" && wand_data_filename) { - query_fun = [&](term_id_vec query) { - return block_max_maxscore_query(wdata, k)(index, query); - }; + query_fun = block_max_maxscore_query(index, wdata, k); } else if (t == "ranked_or" && wand_data_filename) { 
- query_fun = [&](term_id_vec query) { - return ranked_or_query(wdata, k)(index, query); - }; + query_fun = ranked_or_query(index, wdata, k); } else if (t == "maxscore" && wand_data_filename) { - query_fun = [&](term_id_vec query) { - return maxscore_query(wdata, k)(index, query); - }; + query_fun = maxscore_query(index, wdata, k); + } else if (t == "ranked_or_taat" && wand_data_filename) { + query_fun = pisa::make_ranked_or_taat_query(index, wdata, k); } else { spdlog::error("Unsupported query type: {}", t); break; diff --git a/src/thresholds.cpp b/src/thresholds.cpp index d14f26a06..72d0a3d6b 100644 --- a/src/thresholds.cpp +++ b/src/thresholds.cpp @@ -44,9 +44,9 @@ void thresholds(const std::string & index_filename, mapper::map(wdata, md, mapper::map_flags::warmup); } - wand_query query_func(wdata, k); + wand_query query_func(index, wdata, k); for (auto const &query : queries) { - query_func(index, query); + query_func(query); auto results = query_func.topk(); float threshold = 0.0; if (results.size() == k) { diff --git a/test/test_bmw_queries.cpp b/test/test_bmw_queries.cpp index 67b5c82fe..a229e4b21 100644 --- a/test/test_bmw_queries.cpp +++ b/test/test_bmw_queries.cpp @@ -57,11 +57,11 @@ struct index_initialization { template void test_against_wand(QueryOp &op_q) const { - wand_query or_q(wdata, 10); + wand_query or_q(index, wdata, 10); for (auto const &q : queries) { - or_q(index, q); - op_q(index, q); + or_q(q); + op_q(q); REQUIRE(or_q.topk().size() == op_q.topk().size()); for (size_t i = 0; i < or_q.topk().size(); ++i) { @@ -76,9 +76,9 @@ struct index_initialization { } // namespace pisa TEST_CASE_METHOD(pisa::test::index_initialization, "block_max_wand") { - pisa::block_max_wand_query block_max_wand_q(wdata, 10); - pisa::block_max_wand_query block_max_wand_uniform_q(wdata_uniform, 10); - pisa::block_max_wand_query block_max_wand_fixed_q(wdata_fixed, 10); + pisa::block_max_wand_query block_max_wand_q(index, wdata, 10); + pisa::block_max_wand_query 
block_max_wand_uniform_q(index, wdata_uniform, 10); + pisa::block_max_wand_query block_max_wand_fixed_q(index, wdata_fixed, 10); test_against_wand(block_max_wand_uniform_q); test_against_wand(block_max_wand_q); test_against_wand(block_max_wand_fixed_q); diff --git a/test/test_ranked_queries.cpp b/test/test_ranked_queries.cpp index 9ea008366..c3a39592b 100644 --- a/test/test_ranked_queries.cpp +++ b/test/test_ranked_queries.cpp @@ -49,11 +49,11 @@ namespace pisa { namespace test { template void test_against_or(QueryOp &op_q) const { - ranked_or_query or_q(wdata, 10); + ranked_or_query or_q(index, wdata, 10); - for (auto const &q : queries) { - or_q(index, q); - op_q(index, q); + for (auto const& q: queries) { + or_q(q); + op_q(q); REQUIRE(or_q.topk().size() == op_q.topk().size()); for (size_t i = 0; i < or_q.topk().size(); ++i) { REQUIRE(or_q.topk()[i].first == @@ -64,12 +64,12 @@ namespace pisa { namespace test { void test_k_size() const { - ranked_or_query or_10(wdata, 10); - ranked_or_query or_1(wdata, 1); + ranked_or_query or_10(index, wdata, 10); + ranked_or_query or_1(index, wdata, 1); for (auto const &q : queries) { - or_10(index, q); - or_1(index, q); + or_10(q); + or_1(q); if (not or_10.topk().empty()) { REQUIRE(not or_1.topk().empty()); REQUIRE(or_1.topk().front().first == @@ -83,22 +83,29 @@ namespace pisa { namespace test { TEST_CASE_METHOD(pisa::test::index_initialization, "wand") { - pisa::wand_query wand_q(wdata, 10); + pisa::wand_query wand_q(index, wdata, 10); test_against_or(wand_q); } TEST_CASE_METHOD(pisa::test::index_initialization, "maxscore") { - pisa::maxscore_query maxscore_q(wdata, 10); + pisa::maxscore_query maxscore_q(index, wdata, 10); test_against_or(maxscore_q); } TEST_CASE_METHOD(pisa::test::index_initialization, "block_max_maxscore") { - pisa::block_max_maxscore_query bmm_q(wdata, 10); + pisa::block_max_maxscore_query bmm_q(index, wdata, 10); test_against_or(bmm_q); } +TEST_CASE_METHOD(pisa::test::index_initialization, 
"ranked_or_taat") +{ + pisa::ranked_or_taat_query taat_q( + index, wdata, 10); + test_against_or(taat_q); +} + /// Issue #26 https://github.com/pisa-engine/pisa/issues/26 TEST_CASE_METHOD(pisa::test::index_initialization, "topk_size_ranked_or") {