From 12d5fb62edf71b43a8d6ed2e80345fe019b8e9f7 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 26 Apr 2020 19:45:59 +0000 Subject: [PATCH 01/12] Query contaier --- include/pisa/query.hpp | 99 +++++++++++++++++++++++++ src/query.cpp | 161 +++++++++++++++++++++++++++++++++++++++++ test/test_query.cpp | 84 +++++++++++++++++++++ 3 files changed, 344 insertions(+) create mode 100644 include/pisa/query.hpp create mode 100644 src/query.cpp create mode 100644 test/test_query.cpp diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp new file mode 100644 index 000000000..7f481ee08 --- /dev/null +++ b/include/pisa/query.hpp @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace pisa { + +struct QueryContainerInner; + +struct ParsedTerm { + std::uint32_t id; + std::string term; +}; + +using TermProcessorFn = std::function(std::string)>; +using ParseFn = std::function(std::string const&)>; + +class QueryContainer; + +/// Query is a special container that maintains important invariants, such as sorted term IDs, +/// and also has some additional data, like term weights, etc. +class Query { + public: + explicit Query(QueryContainer const& data); + + [[nodiscard]] auto term_ids() const -> gsl::span; + [[nodiscard]] auto threshold() const -> std::optional; + + private: + std::optional m_threshold{}; + std::vector m_term_ids{}; +}; + +class QueryContainer { + public: + QueryContainer(QueryContainer const&); + QueryContainer(QueryContainer&&) noexcept; + QueryContainer& operator=(QueryContainer const&); + QueryContainer& operator=(QueryContainer&&) noexcept; + ~QueryContainer(); + + /// Constructs a query from a raw string. + [[nodiscard]] static auto raw(std::string query_string) -> QueryContainer; + + /// Constructs a query from a list of terms. + /// + /// \param terms List of terms + /// \param term_processor Function executed for each term before stroring them, + /// e.g., stemming or filtering. This function returns + /// `std::optional`, and all `std::nullopt` values + /// will be filtered out. + [[nodiscard]] static auto + from_terms(std::vector terms, std::optional term_processor) + -> QueryContainer; + + /// Constructs a query from a list of term IDs. + [[nodiscard]] static auto from_term_ids(std::vector term_ids) -> QueryContainer; + + // Accessors + + [[nodiscard]] auto string() const noexcept -> std::optional const&; + [[nodiscard]] auto terms() const noexcept -> std::optional> const&; + [[nodiscard]] auto term_ids() const noexcept -> std::optional> const&; + [[nodiscard]] auto threshold() const noexcept -> std::optional const&; + + /// Sets the raw string. + [[nodiscard]] auto string(std::string) -> QueryContainer&; + + /// Sets processed terms. + /// + /// NOTE: If the intent is to parse the query, use `parse` method instead. + /// This method is intended to be used when loading a query from JSON or another + /// external representation. + /// + /// \throws std::domain_error when term IDs are set but the lengths don't match + auto processed_terms(std::vector terms) -> QueryContainer&; + + /// Parses the raw query with the given parser. + /// + /// \throws std::domain_error when raw string is not set + auto parse(ParseFn parse_fn) -> QueryContainer&; + + /// Sets the query score threshold. + auto threshold(float score) -> QueryContainer&; + + /// Returns a query ready to be used for retrieval. + [[nodiscard]] auto query() const -> Query; + + private: + QueryContainer(); + std::unique_ptr m_data; +}; + +} // namespace pisa diff --git a/src/query.cpp b/src/query.cpp new file mode 100644 index 000000000..2838d9dcb --- /dev/null +++ b/src/query.cpp @@ -0,0 +1,161 @@ +#include "query.hpp" + +#include + +#include + +namespace pisa { + +Query::Query(QueryContainer const& data) : m_threshold(data.threshold()) +{ + if (auto term_ids = data.term_ids(); term_ids) { + m_term_ids = *term_ids; + std::sort(m_term_ids.begin(), m_term_ids.end()); + auto last = std::unique(m_term_ids.begin(), m_term_ids.end()); + m_term_ids.erase(last, m_term_ids.end()); + } + throw std::domain_error("Query not parsed."); +} + +auto Query::term_ids() const -> gsl::span +{ + return gsl::span(m_term_ids); +} + +auto Query::threshold() const -> std::optional +{ + return m_threshold; +} + +struct QueryContainerInner { + std::optional query_string; + std::optional> processed_terms; + std::optional> term_ids; + std::optional threshold; +}; + +QueryContainer::QueryContainer() : m_data(std::make_unique()) {} + +QueryContainer::QueryContainer(QueryContainer const& other) + : m_data(std::make_unique(*other.m_data)) +{} +QueryContainer::QueryContainer(QueryContainer&&) noexcept = default; +QueryContainer& QueryContainer::operator=(QueryContainer const& other) +{ + this->m_data = std::make_unique(*other.m_data); + return *this; +} +QueryContainer& QueryContainer::operator=(QueryContainer&&) noexcept = default; +QueryContainer::~QueryContainer() = default; + +auto QueryContainer::raw(std::string query_string) -> QueryContainer +{ + QueryContainer query; + query.m_data->query_string = std::move(query_string); + return query; +} + +auto QueryContainer::from_terms( + std::vector terms, std::optional term_processor) -> QueryContainer +{ + QueryContainer query; + query.m_data->processed_terms = std::vector{}; + auto& processed_terms = *query.m_data->processed_terms; + for (auto&& term: terms) { + if (term_processor) { + auto fn = *term_processor; + if (auto processed = fn(std::move(term)); processed) { + processed_terms.push_back(std::move(*processed)); + } + } else { + processed_terms.push_back(std::move(term)); + } + } + return query; +} + +auto QueryContainer::from_term_ids(std::vector term_ids) -> QueryContainer +{ + QueryContainer query; + query.m_data->term_ids = std::move(term_ids); + return query; +} + +auto QueryContainer::string() const noexcept -> std::optional const& +{ + return m_data->query_string; +} +auto QueryContainer::terms() const noexcept -> std::optional> const& +{ + return m_data->processed_terms; +} + +auto QueryContainer::term_ids() const noexcept -> std::optional> const& +{ + return m_data->term_ids; +} + +auto QueryContainer::threshold() const noexcept -> std::optional const& +{ + return m_data->threshold; +} + +auto QueryContainer::string(std::string raw_query) -> QueryContainer& +{ + m_data->query_string = std::move(raw_query); + return *this; +} + +auto QueryContainer::processed_terms(std::vector terms) -> QueryContainer& +{ + if (auto&& term_ids = m_data->term_ids; term_ids.has_value() && term_ids->size() != terms.size()) { + throw std::domain_error(fmt::format( + "Number of terms ({}) must match number of term IDs ({})", + fmt::join(terms, ", "), + fmt::join(*term_ids, ", "))); + } + m_data->processed_terms = std::move(terms); + return *this; +} + +// auto QueryContainer::term_ids(std::vector term_ids) -> Query& +//{ +// if (auto&& terms = m_data->processed_terms; +// terms.has_value() && terms->size() != term_ids.size()) { +// throw std::domain_error(fmt::format( +// "Number of terms ({}) must match number of term IDs ({})", +// fmt::join(*terms, ", "), +// fmt::join(term_ids, ", "))); +// } +// m_data->term_ids = std::move(term_ids); +// return *this; +//} + +auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& +{ + if (not m_data->query_string) { + throw std::domain_error("Cannot parse, query string not set"); + } + auto parsed_terms = parse_fn(*m_data->query_string); + std::vector processed_terms; + std::vector term_ids; + for (auto&& term: parsed_terms) { + processed_terms.push_back(std::move(term.term)); + term_ids.push_back(term.id); + } + m_data->term_ids = std::move(term_ids); + return *this; +} + +auto QueryContainer::threshold(float score) -> QueryContainer& +{ + m_data->threshold = score; + return *this; +} + +auto QueryContainer::query() const -> Query +{ + return Query(*this); +} + +} // namespace pisa diff --git a/test/test_query.cpp b/test/test_query.cpp new file mode 100644 index 000000000..e65ff1283 --- /dev/null +++ b/test/test_query.cpp @@ -0,0 +1,84 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "query.hpp" + +using pisa::QueryContainer; + +TEST_CASE("Construct from raw string") +{ + auto raw_query = "brooklyn tea house"; + auto query = QueryContainer::raw(raw_query); + REQUIRE(*query.string() == raw_query); +} + +TEST_CASE("Construct from terms") +{ + std::vector terms{"brooklyn", "tea", "house"}; + auto query = QueryContainer::from_terms(terms, std::nullopt); + REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); +} + +TEST_CASE("Construct from terms with processor") +{ + std::vector terms{"brooklyn", "tea", "house"}; + auto proc = [](std::string term) -> std::optional { + if (term.size() > 3) { + return term.substr(0, 4); + } + return std::nullopt; + }; + auto query = QueryContainer::from_terms(terms, proc); + REQUIRE(*query.terms() == std::vector{"broo", "hous"}); +} + +TEST_CASE("Construct from term IDs") +{ + std::vector term_ids{1, 0, 3}; + auto query = QueryContainer::from_term_ids(term_ids); + REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); +} + +TEST_CASE("Set processed terms") +{ + std::vector term_ids{1, 0, 3}; + auto query = QueryContainer::from_term_ids(term_ids); + query.processed_terms(std::vector{"brooklyn", "tea", "house"}); + REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); + REQUIRE_THROWS_AS( + query.processed_terms(std::vector{"tea", "house"}), std::domain_error); +} + +TEST_CASE("Parse query") +{ + auto raw_query = "brooklyn tea house brooklyn"; + auto query = QueryContainer::raw(raw_query); + std::vector lexicon{"house", "brooklyn"}; + auto term_proc = [](std::string term) -> std::optional { return term; }; + query.parse([&](auto&& q) { + std::istringstream is(q); + std::string term; + std::vector parsed_terms; + while (is >> term) { + if (auto t = term_proc(term); t) { + if (auto pos = std::find(lexicon.begin(), lexicon.end(), *t); pos != lexicon.end()) { + auto id = static_cast(std::distance(lexicon.begin(), pos)); + parsed_terms.push_back(pisa::ParsedTerm{id, *t}); + } + } + } + return parsed_terms; + }); + REQUIRE(*query.term_ids() == std::vector{1, 0, 1}); +} + +TEST_CASE("Parsing throws without raw query") +{ + std::vector term_ids{1, 0, 3}; + auto query = QueryContainer::from_term_ids(term_ids); + REQUIRE_THROWS_AS( + query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); +} From 1d3aa8851d77eeef4f4512082e3725878e9eba14 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 26 Apr 2020 23:12:42 +0000 Subject: [PATCH 02/12] Query container parsing --- CMakeLists.txt | 1 + include/pisa/query.hpp | 32 +++++---- src/query.cpp | 125 ++++++++++++++++++++++++++--------- test/test_query.cpp | 79 +++++++++++++++++++--- tools/CMakeLists.txt | 6 ++ tools/app.hpp | 15 +++++ tools/filter_queries.cpp | 138 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 344 insertions(+), 52 deletions(-) create mode 100644 tools/filter_queries.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 86e0f2460..769fee467 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,6 +105,7 @@ target_link_libraries(pisa PUBLIC # TODO(michal): are there any of these we can spdlog fmt::fmt range-v3 + nlohmann_json::nlohmann_json ) target_include_directories(pisa PUBLIC external) diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index 7f481ee08..1b39e3d51 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -24,9 +24,9 @@ class QueryContainer; /// Query is a special container that maintains important invariants, such as sorted term IDs, /// and also has some additional data, like term weights, etc. -class Query { +class QueryRequest { public: - explicit Query(QueryContainer const& data); + explicit QueryRequest(QueryContainer const& data); [[nodiscard]] auto term_ids() const -> gsl::span; [[nodiscard]] auto threshold() const -> std::optional; @@ -61,8 +61,25 @@ class QueryContainer { /// Constructs a query from a list of term IDs. [[nodiscard]] static auto from_term_ids(std::vector term_ids) -> QueryContainer; + /// Constructs a query from a JSON object. + [[nodiscard]] static auto from_json(std::string_view json_string) -> QueryContainer; + + [[nodiscard]] auto to_json() const -> std::string; + + /// Constructs a query from a colon-separated format: + /// + /// ``` + /// id:raw query string + /// ``` + /// or + /// ``` + /// raw query string + /// ``` + [[nodiscard]] static auto from_colon_format(std::string_view line) -> QueryContainer; + // Accessors + [[nodiscard]] auto id() const noexcept -> std::optional const&; [[nodiscard]] auto string() const noexcept -> std::optional const&; [[nodiscard]] auto terms() const noexcept -> std::optional> const&; [[nodiscard]] auto term_ids() const noexcept -> std::optional> const&; @@ -71,15 +88,6 @@ class QueryContainer { /// Sets the raw string. [[nodiscard]] auto string(std::string) -> QueryContainer&; - /// Sets processed terms. - /// - /// NOTE: If the intent is to parse the query, use `parse` method instead. - /// This method is intended to be used when loading a query from JSON or another - /// external representation. - /// - /// \throws std::domain_error when term IDs are set but the lengths don't match - auto processed_terms(std::vector terms) -> QueryContainer&; - /// Parses the raw query with the given parser. /// /// \throws std::domain_error when raw string is not set @@ -89,7 +97,7 @@ class QueryContainer { auto threshold(float score) -> QueryContainer&; /// Returns a query ready to be used for retrieval. - [[nodiscard]] auto query() const -> Query; + [[nodiscard]] auto query() const -> QueryRequest; private: QueryContainer(); diff --git a/src/query.cpp b/src/query.cpp index 2838d9dcb..39dedc846 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -3,10 +3,11 @@ #include #include +#include namespace pisa { -Query::Query(QueryContainer const& data) : m_threshold(data.threshold()) +QueryRequest::QueryRequest(QueryContainer const& data) : m_threshold(data.threshold()) { if (auto term_ids = data.term_ids(); term_ids) { m_term_ids = *term_ids; @@ -17,17 +18,18 @@ Query::Query(QueryContainer const& data) : m_threshold(data.threshold()) throw std::domain_error("Query not parsed."); } -auto Query::term_ids() const -> gsl::span +auto QueryRequest::term_ids() const -> gsl::span { return gsl::span(m_term_ids); } -auto Query::threshold() const -> std::optional +auto QueryRequest::threshold() const -> std::optional { return m_threshold; } struct QueryContainerInner { + std::optional id; std::optional query_string; std::optional> processed_terms; std::optional> term_ids; @@ -81,6 +83,10 @@ auto QueryContainer::from_term_ids(std::vector term_ids) -> Query return query; } +auto QueryContainer::id() const noexcept -> std::optional const& +{ + return m_data->id; +} auto QueryContainer::string() const noexcept -> std::optional const& { return m_data->query_string; @@ -106,31 +112,6 @@ auto QueryContainer::string(std::string raw_query) -> QueryContainer& return *this; } -auto QueryContainer::processed_terms(std::vector terms) -> QueryContainer& -{ - if (auto&& term_ids = m_data->term_ids; term_ids.has_value() && term_ids->size() != terms.size()) { - throw std::domain_error(fmt::format( - "Number of terms ({}) must match number of term IDs ({})", - fmt::join(terms, ", "), - fmt::join(*term_ids, ", "))); - } - m_data->processed_terms = std::move(terms); - return *this; -} - -// auto QueryContainer::term_ids(std::vector term_ids) -> Query& -//{ -// if (auto&& terms = m_data->processed_terms; -// terms.has_value() && terms->size() != term_ids.size()) { -// throw std::domain_error(fmt::format( -// "Number of terms ({}) must match number of term IDs ({})", -// fmt::join(*terms, ", "), -// fmt::join(term_ids, ", "))); -// } -// m_data->term_ids = std::move(term_ids); -// return *this; -//} - auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& { if (not m_data->query_string) { @@ -153,9 +134,93 @@ auto QueryContainer::threshold(float score) -> QueryContainer& return *this; } -auto QueryContainer::query() const -> Query +auto QueryContainer::query() const -> QueryRequest { - return Query(*this); + return QueryRequest(*this); +} + +template +[[nodiscard]] auto get(nlohmann::json const& node, std::string_view field) -> std::optional +{ + if (auto pos = node.find(field); pos != node.end()) { + try { + return std::make_optional(pos->get()); + } catch (nlohmann::detail::exception const& err) { + throw std::runtime_error(fmt::format("Requested field {} is of wrong type", field)); + } + } + return std::optional{}; +} + +auto QueryContainer::from_json(std::string_view json_string) -> QueryContainer +{ + try { + auto json = nlohmann::json::parse(json_string); + QueryContainer query; + QueryContainerInner& data = *query.m_data; + bool at_least_one_required = false; + if (auto id = get(json, "id"); id) { + data.id = std::move(id); + } + if (auto raw = get(json, "query"); raw) { + data.query_string = std::move(raw); + at_least_one_required = true; + } + if (auto terms = get>(json, "terms"); terms) { + data.processed_terms = std::move(terms); + at_least_one_required = true; + } + if (auto term_ids = get>(json, "term_ids"); term_ids) { + data.term_ids = std::move(term_ids); + at_least_one_required = true; + } + if (auto threshold = get(json, "threshold"); threshold) { + data.threshold = threshold; + } + if (not at_least_one_required) { + throw std::invalid_argument(fmt::format( + "JSON must have either raw query, terms, or term IDs: {}", json_string)); + } + return query; + } catch (nlohmann::detail::exception const& err) { + throw std::runtime_error( + fmt::format("Failed to parse JSON: `{}`: {}", json_string, err.what())); + } +} + +auto QueryContainer::to_json() const -> std::string +{ + nlohmann::json json; + if (auto id = m_data->id; id) { + json["id"] = *id; + } + if (auto raw = m_data->query_string; raw) { + json["query"] = *raw; + } + if (auto terms = m_data->processed_terms; terms) { + json["terms"] = *terms; + } + if (auto term_ids = m_data->term_ids; term_ids) { + json["term_ids"] = *term_ids; + } + if (auto threshold = m_data->threshold; threshold) { + json["threshold"] = *threshold; + } + return json.dump(); +} + +auto QueryContainer::from_colon_format(std::string_view line) -> QueryContainer +{ + auto pos = std::find(line.begin(), line.end(), ':'); + QueryContainer query; + QueryContainerInner& data = *query.m_data; + if (pos == line.end()) { + data.query_string = std::string(line); + } else { + data.id = std::string(line.begin(), pos); + data.query_string = std::string(std::next(pos), line.end()); + } + return query; } } // namespace pisa diff --git a/test/test_query.cpp b/test/test_query.cpp index e65ff1283..e3e6128e7 100644 --- a/test/test_query.cpp +++ b/test/test_query.cpp @@ -42,16 +42,6 @@ TEST_CASE("Construct from term IDs") REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); } -TEST_CASE("Set processed terms") -{ - std::vector term_ids{1, 0, 3}; - auto query = QueryContainer::from_term_ids(term_ids); - query.processed_terms(std::vector{"brooklyn", "tea", "house"}); - REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); - REQUIRE_THROWS_AS( - query.processed_terms(std::vector{"tea", "house"}), std::domain_error); -} - TEST_CASE("Parse query") { auto raw_query = "brooklyn tea house brooklyn"; @@ -82,3 +72,72 @@ TEST_CASE("Parsing throws without raw query") REQUIRE_THROWS_AS( query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); } + +TEST_CASE("Parse query container from colon-delimited format") +{ + auto query = QueryContainer::from_colon_format(""); + REQUIRE(query.string()->empty()); + REQUIRE_FALSE(query.id()); + + query = QueryContainer::from_colon_format("brooklyn tea house"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE_FALSE(query.id()); + + query = QueryContainer::from_colon_format("BTH:brooklyn tea house"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE(*query.id() == "BTH"); + + query = QueryContainer::from_colon_format("BTH:"); + REQUIRE(query.string()->empty()); + REQUIRE(*query.id() == "BTH"); +} + +TEST_CASE("Parse query container from JSON") +{ + REQUIRE_THROWS_AS(QueryContainer::from_json(""), std::runtime_error); + REQUIRE_THROWS_AS(QueryContainer::from_json(R"({"id":"ID"})"), std::invalid_argument); + + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house" + } + )"); + REQUIRE(*query.id() == "ID"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE_FALSE(query.terms()); + REQUIRE_FALSE(query.term_ids()); + REQUIRE_FALSE(query.threshold()); + + query = QueryContainer::from_json(R"( + { + "term_ids": [1, 0, 3], + "terms": ["brooklyn", "tea", "house"], + "threshold": 10.8 + } + )"); + REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); + REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); + REQUIRE(*query.threshold() == Approx(10.8)); + REQUIRE_FALSE(query.id()); + REQUIRE_FALSE(query.string()); + + REQUIRE_THROWS_AS(QueryContainer::from_json(R"({"terms":[1, 2]})"), std::runtime_error); +} + +TEST_CASE("Serialize query container to JSON") +{ + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house", + "terms": ["brooklyn", "tea", "house"], + "term_ids": [1, 0, 3], + "threshold": 10.0 + } + )"); + auto serialized = query.to_json(); + REQUIRE( + serialized + == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"threshold":10.0})"); +} diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index d011d56ea..0bc424e99 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -136,3 +136,9 @@ target_link_libraries(reorder-docids pisa CLI11 ) + +add_executable(filter-queries filter_queries.cpp) +target_link_libraries(filter-queries + pisa + CLI11 +) diff --git a/tools/app.hpp b/tools/app.hpp index 3e723e9d8..16a4ac314 100644 --- a/tools/app.hpp +++ b/tools/app.hpp @@ -101,6 +101,21 @@ namespace arg { return q; } + [[nodiscard]] auto term_lexicon() const -> std::optional const& + { + return m_term_lexicon; + } + + [[nodiscard]] auto stemmer() const -> std::optional const& + { + return m_stemmer; + } + + [[nodiscard]] auto stop_words() const -> std::optional const& + { + return m_stop_words; + } + [[nodiscard]] auto k() const -> int { return m_k; } private: diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp new file mode 100644 index 000000000..0b637d979 --- /dev/null +++ b/tools/filter_queries.cpp @@ -0,0 +1,138 @@ +#include + +#include +#include +#include + +#include "app.hpp" +#include "query.hpp" +#include "tokenizer.hpp" + +namespace arg = pisa::arg; +using pisa::QueryContainer; +using pisa::io::for_each_line; + +class TermProcessor { + private: + std::unordered_set stopwords; + + std::function(std::string const&)> m_to_id; + pisa::Stemmer_t m_stemmer; + + public: + TermProcessor( + std::optional const& terms_file, + std::optional const& stopwords_filename, + std::optional const& stemmer_type) + { + auto source = std::make_shared(terms_file->c_str()); + auto terms = pisa::Payload_Vector<>::from(*source); + + m_to_id = [source = std::move(source), terms](auto str) -> std::optional { + // Note: the lexicographical order of the terms matters. + auto pos = std::lower_bound(terms.begin(), terms.end(), std::string_view(str)); + if (*pos == std::string_view(str)) { + return std::distance(terms.begin(), pos); + } + return std::nullopt; + }; + + m_stemmer = pisa::term_processor(stemmer_type); + + if (stopwords_filename) { + std::ifstream is(*stopwords_filename); + pisa::io::for_each_line(is, [&](auto&& word) { + if (auto processed_term = m_to_id(std::move(word)); processed_term.has_value()) { + stopwords.insert(*processed_term); + } + }); + } + } + + [[nodiscard]] std::optional operator()(std::string token) + { + token = m_stemmer(token); + auto id = m_to_id(token); + if (not id) { + return std::nullopt; + } + if (is_stopword(*id)) { + return std::nullopt; + } + return pisa::ParsedTerm{*id, token}; + } + + [[nodiscard]] auto is_stopword(std::uint32_t const term) const -> bool + { + return stopwords.find(term) != stopwords.end(); + } +}; + +enum class Format { Json, Colon }; + +void filter_queries( + std::optional const& query_file, + std::optional const& term_lexicon, + std::optional const& stemmer, + std::size_t min_query_len, + std::size_t max_query_len) +{ + std::optional fmt{}; + auto parser = [term_processor = TermProcessor(term_lexicon, {}, stemmer)](auto query) mutable { + std::vector parsed_terms; + pisa::TermTokenizer tokenizer(query); + for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { + auto term = term_processor(*term_iter); + if (term) { + parsed_terms.push_back(std::move(*term)); + } + } + return parsed_terms; + }; + auto filter = [&](auto&& line) { + auto query = [&] { + if (fmt) { + if (*fmt == Format::Json) { + return QueryContainer::from_json(line); + } + return QueryContainer::from_colon_format(line); + } + try { + auto query = QueryContainer::from_json(line); + fmt = Format::Json; + return query; + } catch (std::exception const& err) { + fmt = Format::Colon; + return QueryContainer::from_colon_format(line); + } + }(); + query.parse(parser); + if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { + std::cout << query.to_json() << '\n'; + } + }; + if (query_file) { + std::ifstream is(*query_file); + for_each_line(is, filter); + } else { + for_each_line(std::cin, filter); + } +} + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + std::size_t min_query_len = 1; + std::size_t max_query_len = std::numeric_limits::max(); + + pisa::App> app( + "Filters out empty queries against a v1 index."); + app.add_option("--min", min_query_len, "Minimum query legth to consider"); + app.add_option("--max", max_query_len, "Maximum query legth to consider"); + CLI11_PARSE(app, argc, argv); + + filter_queries(app.query_file(), app.term_lexicon(), app.stemmer(), min_query_len, max_query_len); + return 0; +} From cdc17f3ae858dd72418bfb5163540a74e077b96a Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 27 Apr 2020 14:08:05 +0000 Subject: [PATCH 03/12] CLI test --- .travis.yml | 10 +++++- include/pisa/query.hpp | 4 +-- src/query.cpp | 1 + test/cli/run.sh | 5 +++ test/cli/setup.sh | 25 ++++++++++++++ test/cli/test_filter_queries.sh | 58 +++++++++++++++++++++++++++++++++ tools/filter_queries.cpp | 14 ++++++-- 7 files changed, 111 insertions(+), 6 deletions(-) create mode 100755 test/cli/run.sh create mode 100755 test/cli/setup.sh create mode 100644 test/cli/test_filter_queries.sh diff --git a/.travis.yml b/.travis.yml index 58c62d718..596f5a76b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -56,7 +56,7 @@ matrix: apt: sources: *all_sources packages: ['g++-9'] - env: MATRIX_EVAL="CC=gcc-9 && CXX=g++-9 && COVERAGE=Off && DOCKER=Off" + env: MATRIX_EVAL="CC=gcc-9 && CXX=g++-9 && COVERAGE=Off && DOCKER=Off && TEST_CLI=On" - os: linux dist: xenial compiler: clang @@ -112,6 +112,11 @@ before_install: brew install ccache; export PATH="/usr/local/opt/ccache/libexec:$PATH"; fi + - if [[ "$TEST_CLI" == "On" ]]; then + git clone https://github.com/sstephenson/bats.git + cd bats + sudo ./install.sh /usr/local + fi - eval "${MATRIX_EVAL}" script: @@ -121,6 +126,9 @@ script: make -j2; if [[ "$TIDY" != "On" ]]; then CTEST_OUTPUT_ON_FAILURE=TRUE ctest -j2; + if [[ "$TEST_CLI" != "On" ]]; then + bash ../test/cli/run.sh + fi fi fi - if [[ "$CLANG_FORMAT" == "On" ]]; then diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index 1b39e3d51..6bb66be5c 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -22,8 +22,8 @@ using ParseFn = std::function(std::string const&)>; class QueryContainer; -/// Query is a special container that maintains important invariants, such as sorted term IDs, -/// and also has some additional data, like term weights, etc. +/// QueryRequest is a special container that maintains important invariants, such as sorted term +/// IDs, and also has some additional data, like term weights, etc. class QueryRequest { public: explicit QueryRequest(QueryContainer const& data); diff --git a/src/query.cpp b/src/query.cpp index 39dedc846..b2aacfca6 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -125,6 +125,7 @@ auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& term_ids.push_back(term.id); } m_data->term_ids = std::move(term_ids); + m_data->processed_terms = std::move(processed_terms); return *this; } diff --git a/test/cli/run.sh b/test/cli/run.sh new file mode 100755 index 000000000..51486a077 --- /dev/null +++ b/test/cli/run.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +DIR=$(dirname "$0") +$DIR/setup.sh +bats $DIR/test_filter_queries.sh diff --git a/test/cli/setup.sh b/test/cli/setup.sh new file mode 100755 index 000000000..c55753f1a --- /dev/null +++ b/test/cli/setup.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# This script should be executed within the build directory that is directly +# in the project directory, e.g., /path/to/pisa/build + +PISA_BIN="./bin" +export PATH="$PISA_BIN:$PATH" + +cat "../test/test_data/clueweb1k.plaintext" | parse_collection \ + --stemmer porter2 \ + --output "./fwd" \ + --format plaintext + +invert --input "./fwd" --output "./inv" + +compress_inverted_index --check \ + --encoding block_simdbp \ + --collection "./inv" \ + --output "./simdbp" + +create_wand_data \ + --scorer bm25 \ + --collection "./inv" \ + --output "./bm25.bmw" \ + --block-size 32 diff --git a/test/cli/test_filter_queries.sh b/test/cli/test_filter_queries.sh new file mode 100644 index 000000000..44e8f0f86 --- /dev/null +++ b/test/cli/test_filter_queries.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bats + +PISA_BIN="bin" +export PATH="$PISA_BIN:$PATH" + +function write_lines { + file=$1 + rm -f "$file" + shift + for line in "$@" + do + echo "$line" >> "$file" + done +} + + +function setup { + write_lines "$BATS_TMPDIR/queries.txt" "brooklyn tea house" "labradoodle" 'Tell your dog I said "hi"' + write_lines "$BATS_TMPDIR/stopwords.txt" "i" "your" +} + +@test "Filter from plain" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stemmer porter2 --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with minimum length" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --min 4 --stemmer porter2 --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with maximum length" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --max 4 --stemmer porter2 --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with stopwords" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stopwords "$BATS_TMPDIR/stopwords.txt" --stemmer porter2 --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,10396,26032,15114],"terms":["tell","dog","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter without stemmer" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"brooklyn tea house","term_ids":[6535,29194],"terms":["brooklyn","tea"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index 0b637d979..17748126c 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -74,11 +74,13 @@ void filter_queries( std::optional const& query_file, std::optional const& term_lexicon, std::optional const& stemmer, + std::optional const& stopwords_filename, std::size_t min_query_len, std::size_t max_query_len) { std::optional fmt{}; - auto parser = [term_processor = TermProcessor(term_lexicon, {}, stemmer)](auto query) mutable { + auto parser = [term_processor = TermProcessor(term_lexicon, stopwords_filename, stemmer)]( + auto query) mutable { std::vector parsed_terms; pisa::TermTokenizer tokenizer(query); for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { @@ -127,12 +129,18 @@ int main(int argc, char** argv) std::size_t min_query_len = 1; std::size_t max_query_len = std::numeric_limits::max(); - pisa::App> app( + pisa::App> app( "Filters out empty queries against a v1 index."); app.add_option("--min", min_query_len, "Minimum query legth to consider"); app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); - filter_queries(app.query_file(), app.term_lexicon(), app.stemmer(), min_query_len, max_query_len); + filter_queries( + app.query_file(), + app.term_lexicon(), + app.stemmer(), + app.stop_words(), + min_query_len, + max_query_len); return 0; } From b0e5d1a7a84f05292aa92fec4772ae06b9da18fa Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 27 Apr 2020 15:57:42 +0000 Subject: [PATCH 04/12] Fix .travis.yml syntax --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 596f5a76b..685313f01 100644 --- a/.travis.yml +++ b/.travis.yml @@ -127,7 +127,7 @@ script: if [[ "$TIDY" != "On" ]]; then CTEST_OUTPUT_ON_FAILURE=TRUE ctest -j2; if [[ "$TEST_CLI" != "On" ]]; then - bash ../test/cli/run.sh + bash ../test/cli/run.sh; fi fi fi From 6e2ab62bf185afb0cce0ddf8b37d269f561c80d0 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 27 Apr 2020 16:08:55 +0000 Subject: [PATCH 05/12] Fix .travis.yml syntax --- .travis.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 685313f01..34604e9d5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -113,9 +113,9 @@ before_install: export PATH="/usr/local/opt/ccache/libexec:$PATH"; fi - if [[ "$TEST_CLI" == "On" ]]; then - git clone https://github.com/sstephenson/bats.git - cd bats - sudo ./install.sh /usr/local + git clone https://github.com/sstephenson/bats.git; + cd bats; + sudo ./install.sh /usr/local; fi - eval "${MATRIX_EVAL}" From 2cce2cdaea040abc7e11b2e0982c0c135177f5a1 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 27 Apr 2020 16:42:21 +0000 Subject: [PATCH 06/12] Fix when cli test are executed --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 34604e9d5..241028838 100644 --- a/.travis.yml +++ b/.travis.yml @@ -126,7 +126,7 @@ script: make -j2; if [[ "$TIDY" != "On" ]]; then CTEST_OUTPUT_ON_FAILURE=TRUE ctest -j2; - if [[ "$TEST_CLI" != "On" ]]; then + if [[ "$TEST_CLI" == "On" ]]; then bash ../test/cli/run.sh; fi fi From 1838258a1d0c69c0edf2a43dccd9c81ea348077d Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 28 Apr 2020 23:47:18 +0000 Subject: [PATCH 07/12] Refactor out common code from tool --- include/pisa/query.hpp | 37 ++++++++- include/pisa/query/parser.hpp | 52 +++++++++++++ src/query.cpp | 39 ++++++++++ src/query/parser.cpp | 94 +++++++++++++++++++++++ test/cli/test_filter_queries.sh | 26 +++++-- test/test_query.cpp | 6 +- test/test_query_parser.cpp | 28 +++++++ tools/filter_queries.cpp | 130 ++++++++------------------------ 8 files changed, 302 insertions(+), 110 deletions(-) create mode 100644 include/pisa/query/parser.hpp create mode 100644 src/query/parser.cpp create mode 100644 test/test_query_parser.cpp diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index 6bb66be5c..f40055e06 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -12,13 +13,13 @@ namespace pisa { struct QueryContainerInner; -struct ParsedTerm { +struct ResolvedTerm { std::uint32_t id; std::string term; }; using TermProcessorFn = std::function(std::string)>; -using ParseFn = std::function(std::string const&)>; +using ParseFn = std::function(std::string const&)>; class QueryContainer; @@ -104,4 +105,36 @@ class QueryContainer { std::unique_ptr m_data; }; +enum class Format { Json, Colon }; + +class QueryReader { + public: + /// Open reader from file. + static auto from_file(std::string const& file) -> QueryReader; + /// Open reader from stdin. + static auto from_stdin() -> QueryReader; + + /// Read next query or return `nullopt` if stream has ended. + [[nodiscard]] auto next() -> std::optional; + + /// Execute `fn(q)` for each query `q`. + template + void for_each(Fn&& fn) + { + auto query = next(); + while (query) { + fn(std::move(*query)); + query = next(); + } + } + + private: + explicit QueryReader(std::unique_ptr stream, std::istream& stream_ref); + + std::unique_ptr m_stream; + std::istream& m_stream_ref; + std::string m_line_buf{}; + std::optional m_format{}; +}; + } // namespace pisa diff --git a/include/pisa/query/parser.hpp b/include/pisa/query/parser.hpp new file mode 100644 index 000000000..589c5c52b --- /dev/null +++ b/include/pisa/query/parser.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include + +#include "query.hpp" + +namespace pisa { + +using TermResolver = std::function(std::string)>; + +struct StandardTermResolverParams; + +/// Provides a standard implementation of `TermResolver`. +class StandardTermResolver { + public: + StandardTermResolver( + std::string const& term_lexicon_path, + std::optional const& stopwords_filename, + std::optional const& stemmer_type); + StandardTermResolver(StandardTermResolver const&); + StandardTermResolver(StandardTermResolver&&) noexcept; + StandardTermResolver& operator=(StandardTermResolver const&); + StandardTermResolver& operator=(StandardTermResolver&&) noexcept; + ~StandardTermResolver(); + + [[nodiscard]] auto operator()(std::string token) const -> std::optional; + + private: + [[nodiscard]] auto is_stopword(std::uint32_t const term) const -> bool; + + std::unique_ptr m_self; +}; + +/// Parses a query string to processed terms. +class QueryParser { + public: + explicit QueryParser(TermResolver term_processor); + /// Given a query string, it returns a list of (possibly processed) terms. + /// + /// Possible transformations of terms include lower-casing and stemming. + /// Some terms could be also removed, e.g., because they are on a list of + /// stop words. The exact implementation depends on the term processor + /// passed to the constructor. + auto operator()(std::string const&) -> std::vector; + + private: + TermResolver m_term_resolver; +}; + +} // namespace pisa diff --git a/src/query.cpp b/src/query.cpp index b2aacfca6..b202ce172 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -1,6 +1,8 @@ #include "query.hpp" #include +#include +#include #include #include @@ -224,4 +226,41 @@ auto QueryContainer::from_colon_format(std::string_view line) -> QueryContainer return query; } +auto QueryReader::from_file(std::string const& file) -> QueryReader +{ + auto input = std::make_unique(file); + auto& ref = *input; + return QueryReader(std::move(input), ref); +} + +auto QueryReader::from_stdin() -> QueryReader +{ + return QueryReader(nullptr, std::cin); +} + +QueryReader::QueryReader(std::unique_ptr input, std::istream& stream_ref) + : m_stream(std::move(input)), m_stream_ref(stream_ref) +{} + +auto QueryReader::next() -> std::optional +{ + if (std::getline(m_stream_ref, m_line_buf)) { + if (m_format) { + if (*m_format == Format::Json) { + return QueryContainer::from_json(m_line_buf); + } + return QueryContainer::from_colon_format(m_line_buf); + } + try { + auto query = QueryContainer::from_json(m_line_buf); + m_format = Format::Json; + return query; + } catch (std::exception const& err) { + m_format = Format::Colon; + return QueryContainer::from_colon_format(m_line_buf); + } + } + return std::nullopt; +} + } // namespace pisa diff --git a/src/query/parser.cpp b/src/query/parser.cpp new file mode 100644 index 000000000..178a425bf --- /dev/null +++ b/src/query/parser.cpp @@ -0,0 +1,94 @@ +#include + +#include "io.hpp" +#include "payload_vector.hpp" +#include "query.hpp" +#include "query/parser.hpp" +#include "query/term_processor.hpp" +#include "tokenizer.hpp" + +namespace pisa { + +StandardTermResolver::StandardTermResolver(StandardTermResolver const& other) + : m_self(std::make_unique(*other.m_self)) +{} +StandardTermResolver::StandardTermResolver(StandardTermResolver&&) noexcept = default; +StandardTermResolver& StandardTermResolver::operator=(StandardTermResolver const& other) +{ + m_self = std::make_unique(*other.m_self); + return *this; +} +StandardTermResolver& StandardTermResolver::operator=(StandardTermResolver&&) noexcept = default; +StandardTermResolver::~StandardTermResolver() = default; + +struct StandardTermResolverParams { + std::vector stopwords; + std::function(std::string const&)> to_id; + std::function transform; +}; + +StandardTermResolver::StandardTermResolver( + std::string const& term_lexicon_path, + std::optional const& stopwords_filename, + std::optional const& stemmer_type) + : m_self(std::make_unique()) +{ + auto source = std::make_shared(term_lexicon_path.c_str()); + auto terms = pisa::Payload_Vector<>::from(*source); + + m_self->to_id = [source = std::move(source), terms](auto str) -> std::optional { + auto pos = std::lower_bound(terms.begin(), terms.end(), std::string_view(str)); + if (*pos == std::string_view(str)) { + return std::distance(terms.begin(), pos); + } + return std::nullopt; + }; + + m_self->transform = pisa::term_processor(stemmer_type); + + if (stopwords_filename) { + std::ifstream is(*stopwords_filename); + pisa::io::for_each_line(is, [&](auto&& word) { + if (auto term_id = m_self->to_id(std::move(word)); term_id.has_value()) { + m_self->stopwords.push_back(*term_id); + } + }); + std::sort(m_self->stopwords.begin(), m_self->stopwords.end()); + } +} + +auto StandardTermResolver::operator()(std::string token) const -> std::optional +{ + token = m_self->transform(token); + auto id = m_self->to_id(token); + if (not id) { + return std::nullopt; + } + if (is_stopword(*id)) { + return std::nullopt; + } + return pisa::ResolvedTerm{*id, token}; +} + +auto StandardTermResolver::is_stopword(std::uint32_t const term) const -> bool +{ + auto pos = std::lower_bound(m_self->stopwords.begin(), m_self->stopwords.end(), term); + return pos != m_self->stopwords.end() && *pos == term; +} + +QueryParser::QueryParser(TermResolver term_resolver) : m_term_resolver(std::move(term_resolver)) {} + +auto QueryParser::operator()(std::string const& query) -> std::vector +{ + TermTokenizer tokenizer(query); + std::vector terms; + for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { + auto term = m_term_resolver(*term_iter); + if (term) { + terms.push_back(std::move(*term)); + } + } + return terms; +} + +} // namespace pisa diff --git a/test/cli/test_filter_queries.sh b/test/cli/test_filter_queries.sh index 44e8f0f86..2fa486074 100644 --- a/test/cli/test_filter_queries.sh +++ b/test/cli/test_filter_queries.sh @@ -21,7 +21,6 @@ function setup { @test "Filter from plain" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stemmer porter2 --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} {"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' [[ "$result" = "$expected" ]] @@ -29,21 +28,18 @@ function setup { @test "Filter with minimum length" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --min 4 --stemmer porter2 --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' [[ "$result" = "$expected" ]] } @test "Filter with maximum length" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --max 4 --stemmer porter2 --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]}' [[ "$result" = "$expected" ]] } @test "Filter with stopwords" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stopwords "$BATS_TMPDIR/stopwords.txt" --stemmer porter2 --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} {"query":"Tell your dog I said \"hi\"","term_ids":[29287,10396,26032,15114],"terms":["tell","dog","said","hi"]}' [[ "$result" = "$expected" ]] @@ -51,8 +47,28 @@ function setup { @test "Filter without stemmer" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"brooklyn tea house","term_ids":[6535,29194],"terms":["brooklyn","tea"]} {"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' [[ "$result" = "$expected" ]] } + +@test "Accept JSON" { + echo '{"query":"brooklyn tea house"}' > "$BATS_TMPDIR/queries.json" + result=$(cat $BATS_TMPDIR/queries.json | filter-queries --terms ./fwd.termlex) + expected='{"query":"brooklyn tea house","term_ids":[6535,29194],"terms":["brooklyn","tea"]}' + [[ "$result" = "$expected" ]] +} + +@test "Accept JSON without --terms if already parsed" { + echo '{"term_ids":[6535,29194]}' > "$BATS_TMPDIR/queries.json" + result=$(cat $BATS_TMPDIR/queries.json | filter-queries) + expected='{"term_ids":[6535,29194]}' + [[ "$result" = "$expected" ]] +} + +@test "Fail when no --terms and not parsed" { + echo '{"query":"brooklyn tea house"}' > "$BATS_TMPDIR/queries.json" + run filter-queries < $BATS_TMPDIR/queries.json + [[ "$status" -eq 1 ]] + [[ "$output" = *"[error] Unresoved queries (without IDs) require term lexicon." ]] +} diff --git a/test/test_query.cpp b/test/test_query.cpp index e3e6128e7..00b3aea0e 100644 --- a/test/test_query.cpp +++ b/test/test_query.cpp @@ -51,12 +51,12 @@ TEST_CASE("Parse query") query.parse([&](auto&& q) { std::istringstream is(q); std::string term; - std::vector parsed_terms; + std::vector parsed_terms; while (is >> term) { if (auto t = term_proc(term); t) { if (auto pos = std::find(lexicon.begin(), lexicon.end(), *t); pos != lexicon.end()) { auto id = static_cast(std::distance(lexicon.begin(), pos)); - parsed_terms.push_back(pisa::ParsedTerm{id, *t}); + parsed_terms.push_back(pisa::ResolvedTerm{id, *t}); } } } @@ -70,7 +70,7 @@ TEST_CASE("Parsing throws without raw query") std::vector term_ids{1, 0, 3}; auto query = QueryContainer::from_term_ids(term_ids); REQUIRE_THROWS_AS( - query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); + query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); } TEST_CASE("Parse query container from colon-delimited format") diff --git a/test/test_query_parser.cpp b/test/test_query_parser.cpp new file mode 100644 index 000000000..32ba2b32c --- /dev/null +++ b/test/test_query_parser.cpp @@ -0,0 +1,28 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "query/parser.hpp" + +using pisa::QueryContainer; +using pisa::QueryParser; + +TEST_CASE("Parse with lower-case processor and stop word") +{ + std::uint32_t init_id = 0; + auto term_proc = [id = init_id](auto&& term) mutable { + std::transform( + term.begin(), term.end(), term.begin(), [](unsigned char c) { return std::tolower(c); }); + if (term == "house") { + return std::optional{}; + } + return std::optional{pisa::ResolvedTerm{id++, term}}; + }; + QueryParser parser(term_proc); + auto terms = parser("Brooklyn tea house"); + REQUIRE(terms.size() == 2); + REQUIRE(terms[0].term == "brooklyn"); + REQUIRE(terms[1].term == "tea"); +} diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index 17748126c..e29a66b44 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -6,68 +6,17 @@ #include "app.hpp" #include "query.hpp" +#include "query/parser.hpp" #include "tokenizer.hpp" namespace arg = pisa::arg; + using pisa::QueryContainer; +using pisa::QueryParser; +using pisa::QueryReader; +using pisa::StandardTermResolver; using pisa::io::for_each_line; -class TermProcessor { - private: - std::unordered_set stopwords; - - std::function(std::string const&)> m_to_id; - pisa::Stemmer_t m_stemmer; - - public: - TermProcessor( - std::optional const& terms_file, - std::optional const& stopwords_filename, - std::optional const& stemmer_type) - { - auto source = std::make_shared(terms_file->c_str()); - auto terms = pisa::Payload_Vector<>::from(*source); - - m_to_id = [source = std::move(source), terms](auto str) -> std::optional { - // Note: the lexicographical order of the terms matters. - auto pos = std::lower_bound(terms.begin(), terms.end(), std::string_view(str)); - if (*pos == std::string_view(str)) { - return std::distance(terms.begin(), pos); - } - return std::nullopt; - }; - - m_stemmer = pisa::term_processor(stemmer_type); - - if (stopwords_filename) { - std::ifstream is(*stopwords_filename); - pisa::io::for_each_line(is, [&](auto&& word) { - if (auto processed_term = m_to_id(std::move(word)); processed_term.has_value()) { - stopwords.insert(*processed_term); - } - }); - } - } - - [[nodiscard]] std::optional operator()(std::string token) - { - token = m_stemmer(token); - auto id = m_to_id(token); - if (not id) { - return std::nullopt; - } - if (is_stopword(*id)) { - return std::nullopt; - } - return pisa::ParsedTerm{*id, token}; - } - - [[nodiscard]] auto is_stopword(std::uint32_t const term) const -> bool - { - return stopwords.find(term) != stopwords.end(); - } -}; - enum class Format { Json, Colon }; void filter_queries( @@ -78,47 +27,23 @@ void filter_queries( std::size_t min_query_len, std::size_t max_query_len) { - std::optional fmt{}; - auto parser = [term_processor = TermProcessor(term_lexicon, stopwords_filename, stemmer)]( - auto query) mutable { - std::vector parsed_terms; - pisa::TermTokenizer tokenizer(query); - for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { - auto term = term_processor(*term_iter); - if (term) { - parsed_terms.push_back(std::move(*term)); - } + auto reader = [&] { + if (query_file) { + return QueryReader::from_file(*query_file); } - return parsed_terms; - }; - auto filter = [&](auto&& line) { - auto query = [&] { - if (fmt) { - if (*fmt == Format::Json) { - return QueryContainer::from_json(line); - } - return QueryContainer::from_colon_format(line); - } - try { - auto query = QueryContainer::from_json(line); - fmt = Format::Json; - return query; - } catch (std::exception const& err) { - fmt = Format::Colon; - return QueryContainer::from_colon_format(line); + return QueryReader::from_stdin(); + }(); + reader.for_each([&](auto query) { + if (not query.term_ids()) { + if (not term_lexicon) { + throw std::runtime_error("Unresoved queries (without IDs) require term lexicon."); } - }(); - query.parse(parser); + query.parse(QueryParser(StandardTermResolver(*term_lexicon, stopwords_filename, stemmer))); + } if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { std::cout << query.to_json() << '\n'; } - }; - if (query_file) { - std::ifstream is(*query_file); - for_each_line(is, filter); - } else { - for_each_line(std::cin, filter); - } + }); } int main(int argc, char** argv) @@ -135,12 +60,17 @@ int main(int argc, char** argv) app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); - filter_queries( - app.query_file(), - app.term_lexicon(), - app.stemmer(), - app.stop_words(), - min_query_len, - max_query_len); - return 0; + try { + filter_queries( + app.query_file(), + app.term_lexicon(), + app.stemmer(), + app.stop_words(), + min_query_len, + max_query_len); + return 0; + } catch (std::runtime_error const& err) { + spdlog::error(err.what()); + return 1; + } } From 7107f65efa82bc611ef7d0ed83703a15b22ee5ce Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 1 May 2020 14:53:28 +0000 Subject: [PATCH 08/12] Small refactoring and term resolver tests --- include/pisa/query.hpp | 2 + include/pisa/query/query_parser.hpp | 26 +++++++ .../query/{parser.hpp => term_resolver.hpp} | 31 +++++---- src/query.cpp | 12 ++++ src/query/query_parser.cpp | 27 ++++++++ src/query/{parser.cpp => term_resolver.cpp} | 42 ++++++----- test/test_query_parser.cpp | 2 +- test/test_term_resolver.cpp | 69 +++++++++++++++++++ tools/filter_queries.cpp | 48 ++++--------- 9 files changed, 189 insertions(+), 70 deletions(-) create mode 100644 include/pisa/query/query_parser.hpp rename include/pisa/query/{parser.hpp => term_resolver.hpp} (64%) create mode 100644 src/query/query_parser.cpp rename src/query/{parser.cpp => term_resolver.cpp} (76%) create mode 100644 test/test_term_resolver.cpp diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index f40055e06..eb6488c0c 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -45,6 +45,8 @@ class QueryContainer { QueryContainer& operator=(QueryContainer&&) noexcept; ~QueryContainer(); + [[nodiscard]] auto operator==(QueryContainer const& other) const noexcept -> bool; + /// Constructs a query from a raw string. [[nodiscard]] static auto raw(std::string query_string) -> QueryContainer; diff --git a/include/pisa/query/query_parser.hpp b/include/pisa/query/query_parser.hpp new file mode 100644 index 000000000..a3500962a --- /dev/null +++ b/include/pisa/query/query_parser.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "query.hpp" +#include "term_resolver.hpp" + +namespace pisa { + +/// Parses a query string to processed terms. +class QueryParser { + public: + explicit QueryParser(TermResolver term_processor); + /// Given a query string, it returns a list of (possibly processed) terms. + /// + /// Possible transformations of terms include lower-casing and stemming. + /// Some terms could be also removed, e.g., because they are on a list of + /// stop words. The exact implementation depends on the term processor + /// passed to the constructor. + auto operator()(std::string const&) -> std::vector; + + private: + TermResolver m_term_resolver; +}; + +} // namespace pisa diff --git a/include/pisa/query/parser.hpp b/include/pisa/query/term_resolver.hpp similarity index 64% rename from include/pisa/query/parser.hpp rename to include/pisa/query/term_resolver.hpp index 589c5c52b..fc9088a48 100644 --- a/include/pisa/query/parser.hpp +++ b/include/pisa/query/term_resolver.hpp @@ -8,6 +8,10 @@ namespace pisa { +/// Thrown if expected resolver but none found. +struct MissingResolverError { +}; + using TermResolver = std::function(std::string)>; struct StandardTermResolverParams; @@ -33,20 +37,17 @@ class StandardTermResolver { std::unique_ptr m_self; }; -/// Parses a query string to processed terms. -class QueryParser { - public: - explicit QueryParser(TermResolver term_processor); - /// Given a query string, it returns a list of (possibly processed) terms. - /// - /// Possible transformations of terms include lower-casing and stemming. - /// Some terms could be also removed, e.g., because they are on a list of - /// stop words. The exact implementation depends on the term processor - /// passed to the constructor. - auto operator()(std::string const&) -> std::vector; - - private: - TermResolver m_term_resolver; -}; +/// Reads queries from `query_file`, resolves them with `term_resolver`, filters by +/// query length (number of resolved terms in the query), and prints the selected +/// queries to `out`. +/// +/// \throws MissingResolverError When no resolver passed but queries don't have IDs resolved. +// +void filter_queries( + std::optional const& query_file, + std::optional term_resolver, + std::size_t min_query_len, + std::size_t max_query_len, + std::ostream& out); } // namespace pisa diff --git a/src/query.cpp b/src/query.cpp index b202ce172..1b54efe66 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -36,6 +36,13 @@ struct QueryContainerInner { std::optional> processed_terms; std::optional> term_ids; std::optional threshold; + + [[nodiscard]] auto operator==(QueryContainerInner const& other) const noexcept -> bool + { + return id == other.id && query_string == other.query_string + && processed_terms == other.processed_terms && term_ids == other.term_ids + && threshold == other.threshold; + } }; QueryContainer::QueryContainer() : m_data(std::make_unique()) {} @@ -52,6 +59,11 @@ QueryContainer& QueryContainer::operator=(QueryContainer const& other) QueryContainer& QueryContainer::operator=(QueryContainer&&) noexcept = default; QueryContainer::~QueryContainer() = default; +auto QueryContainer::operator==(QueryContainer const& other) const noexcept -> bool +{ + return *m_data == *other.m_data; +} + auto QueryContainer::raw(std::string query_string) -> QueryContainer { QueryContainer query; diff --git a/src/query/query_parser.cpp b/src/query/query_parser.cpp new file mode 100644 index 000000000..fc64a7891 --- /dev/null +++ b/src/query/query_parser.cpp @@ -0,0 +1,27 @@ +#include + +#include "io.hpp" +#include "payload_vector.hpp" +#include "query.hpp" +#include "query/query_parser.hpp" +#include "query/term_resolver.hpp" +#include "tokenizer.hpp" + +namespace pisa { + +QueryParser::QueryParser(TermResolver term_resolver) : m_term_resolver(std::move(term_resolver)) {} + +auto QueryParser::operator()(std::string const& query) -> std::vector +{ + TermTokenizer tokenizer(query); + std::vector terms; + for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { + auto term = m_term_resolver(*term_iter); + if (term) { + terms.push_back(std::move(*term)); + } + } + return terms; +} + +} // namespace pisa diff --git a/src/query/parser.cpp b/src/query/term_resolver.cpp similarity index 76% rename from src/query/parser.cpp rename to src/query/term_resolver.cpp index 178a425bf..7d00f59ad 100644 --- a/src/query/parser.cpp +++ b/src/query/term_resolver.cpp @@ -1,11 +1,6 @@ -#include - -#include "io.hpp" -#include "payload_vector.hpp" -#include "query.hpp" -#include "query/parser.hpp" +#include "query/term_resolver.hpp" +#include "query/query_parser.hpp" #include "query/term_processor.hpp" -#include "tokenizer.hpp" namespace pisa { @@ -76,19 +71,30 @@ auto StandardTermResolver::is_stopword(std::uint32_t const term) const -> bool return pos != m_self->stopwords.end() && *pos == term; } -QueryParser::QueryParser(TermResolver term_resolver) : m_term_resolver(std::move(term_resolver)) {} - -auto QueryParser::operator()(std::string const& query) -> std::vector +void filter_queries( + std::optional const& query_file, + std::optional term_resolver, + std::size_t min_query_len, + std::size_t max_query_len, + std::ostream& out) { - TermTokenizer tokenizer(query); - std::vector terms; - for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { - auto term = m_term_resolver(*term_iter); - if (term) { - terms.push_back(std::move(*term)); + auto reader = [&] { + if (query_file) { + return QueryReader::from_file(*query_file); } - } - return terms; + return QueryReader::from_stdin(); + }(); + reader.for_each([&](auto query) { + if (not query.term_ids()) { + if (not term_resolver) { + throw MissingResolverError{}; + } + query.parse(QueryParser(*term_resolver)); + } + if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { + out << query.to_json() << '\n'; + } + }); } } // namespace pisa diff --git a/test/test_query_parser.cpp b/test/test_query_parser.cpp index 32ba2b32c..95f028124 100644 --- a/test/test_query_parser.cpp +++ b/test/test_query_parser.cpp @@ -4,7 +4,7 @@ #include -#include "query/parser.hpp" +#include "query/query_parser.hpp" using pisa::QueryContainer; using pisa::QueryParser; diff --git a/test/test_term_resolver.cpp b/test/test_term_resolver.cpp new file mode 100644 index 000000000..913060b85 --- /dev/null +++ b/test/test_term_resolver.cpp @@ -0,0 +1,69 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "io.hpp" +#include "query/term_resolver.hpp" +#include "temporary_directory.hpp" + +using pisa::QueryContainer; +using pisa::StandardTermResolver; + +TEST_CASE("Filter queries") +{ + std::uint32_t id = 0; + auto term_resolver = [&id](auto&& term) mutable { + return std::optional{pisa::ResolvedTerm{id++, term}}; + }; + Temporary_Directory tmp; + auto input = (tmp.path() / "input.txt"); + { + std::ofstream os(input.c_str()); + os << "a b c d\n"; + os << "e\n"; + os << "f g h i j\n"; + os << "k l m\n"; + os << "n o\n"; + } + + SECTION("Between 2 and 4") + { + std::ostringstream os; + pisa::filter_queries( + std::make_optional(input.string()), std::make_optional(term_resolver), 2, 4, os); + std::vector queries; + std::istringstream is(os.str()); + pisa::io::for_each_line( + is, [&queries](auto&& line) { queries.push_back(QueryContainer::from_json(line)); }); + REQUIRE(queries.size() == 3); + REQUIRE(*queries[0].terms() == std::vector{"a", "b", "c", "d"}); + REQUIRE(*queries[0].term_ids() == std::vector{0, 1, 2, 3}); + REQUIRE(*queries[1].terms() == std::vector{"k", "l", "m"}); + REQUIRE(*queries[1].term_ids() == std::vector{10, 11, 12}); + REQUIRE(*queries[2].terms() == std::vector{"n", "o"}); + REQUIRE(*queries[2].term_ids() == std::vector{13, 14}); + + SECTION("Don't fail if no resolver but IDs already resolved") + { + auto json_input = (tmp.path() / "input.json"); + { + std::ofstream json_out(json_input.c_str()); + for (auto&& query: queries) { + json_out << query.to_json() << '\n'; + } + } + std::ostringstream output; + pisa::filter_queries(std::make_optional(json_input.string()), std::nullopt, 2, 4, output); + REQUIRE(output.str() == os.str()); + } + } + + SECTION("Fail without IDs and resolver") + { + REQUIRE_THROWS_AS( + pisa::filter_queries(std::make_optional(input.string()), std::nullopt, 2, 4, std::cerr), + pisa::MissingResolverError); + } +} diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index e29a66b44..1b1507c81 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -6,46 +6,20 @@ #include "app.hpp" #include "query.hpp" -#include "query/parser.hpp" +#include "query/query_parser.hpp" +#include "query/term_resolver.hpp" #include "tokenizer.hpp" namespace arg = pisa::arg; +using pisa::filter_queries; using pisa::QueryContainer; using pisa::QueryParser; using pisa::QueryReader; using pisa::StandardTermResolver; +using pisa::TermResolver; using pisa::io::for_each_line; -enum class Format { Json, Colon }; - -void filter_queries( - std::optional const& query_file, - std::optional const& term_lexicon, - std::optional const& stemmer, - std::optional const& stopwords_filename, - std::size_t min_query_len, - std::size_t max_query_len) -{ - auto reader = [&] { - if (query_file) { - return QueryReader::from_file(*query_file); - } - return QueryReader::from_stdin(); - }(); - reader.for_each([&](auto query) { - if (not query.term_ids()) { - if (not term_lexicon) { - throw std::runtime_error("Unresoved queries (without IDs) require term lexicon."); - } - query.parse(QueryParser(StandardTermResolver(*term_lexicon, stopwords_filename, stemmer))); - } - if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { - std::cout << query.to_json() << '\n'; - } - }); -} - int main(int argc, char** argv) { spdlog::drop(""); @@ -60,15 +34,17 @@ int main(int argc, char** argv) app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); + std::optional term_resolver{}; + if (app.term_lexicon()) { + term_resolver = StandardTermResolver(*app.term_lexicon(), app.stop_words(), app.stemmer()); + } + try { filter_queries( - app.query_file(), - app.term_lexicon(), - app.stemmer(), - app.stop_words(), - min_query_len, - max_query_len); + app.query_file(), std::move(term_resolver), min_query_len, max_query_len, std::cout); return 0; + } catch (pisa::MissingResolverError err) { + spdlog::error("Unresoved queries(without IDs) require term lexicon."); } catch (std::runtime_error const& err) { spdlog::error(err.what()); return 1; From ede9c984a745fae61f2120be4dfa23892a4bdad6 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 1 May 2020 14:55:13 +0000 Subject: [PATCH 09/12] Fix tool description --- tools/filter_queries.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index 1b1507c81..dbc89469e 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -28,8 +28,7 @@ int main(int argc, char** argv) std::size_t min_query_len = 1; std::size_t max_query_len = std::numeric_limits::max(); - pisa::App> app( - "Filters out empty queries against a v1 index."); + pisa::App> app("Filters queries by their length"); app.add_option("--min", min_query_len, "Minimum query legth to consider"); app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); From b8f625cd5fd08f53297a151f6d391e275ab97ebd Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 3 May 2020 01:42:17 +0000 Subject: [PATCH 10/12] Multiple thresholds per query --- include/pisa/query.hpp | 17 +++++++--- src/query.cpp | 73 ++++++++++++++++++++++++++++++++++-------- test/test_query.cpp | 10 +++--- 3 files changed, 76 insertions(+), 24 deletions(-) diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index eb6488c0c..dbec6ba51 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -27,12 +27,14 @@ class QueryContainer; /// IDs, and also has some additional data, like term weights, etc. class QueryRequest { public: - explicit QueryRequest(QueryContainer const& data); + explicit QueryRequest(QueryContainer const& data, std::size_t k); [[nodiscard]] auto term_ids() const -> gsl::span; [[nodiscard]] auto threshold() const -> std::optional; + [[nodiscard]] auto k() const -> std::optional; private: + std::size_t m_k; std::optional m_threshold{}; std::vector m_term_ids{}; }; @@ -86,7 +88,9 @@ class QueryContainer { [[nodiscard]] auto string() const noexcept -> std::optional const&; [[nodiscard]] auto terms() const noexcept -> std::optional> const&; [[nodiscard]] auto term_ids() const noexcept -> std::optional> const&; - [[nodiscard]] auto threshold() const noexcept -> std::optional const&; + [[nodiscard]] auto threshold(std::size_t k) const noexcept -> std::optional; + [[nodiscard]] auto thresholds() const noexcept + -> std::vector> const&; /// Sets the raw string. [[nodiscard]] auto string(std::string) -> QueryContainer&; @@ -96,11 +100,14 @@ class QueryContainer { /// \throws std::domain_error when raw string is not set auto parse(ParseFn parse_fn) -> QueryContainer&; - /// Sets the query score threshold. - auto threshold(float score) -> QueryContainer&; + /// Sets the query score threshold for `k`. + /// + /// If another threshold for the same `k` exists, it will be replaced, + /// and `true` will be returned. Otherwise, `false` will be returned. + auto add_threshold(std::size_t k, float score) -> bool; /// Returns a query ready to be used for retrieval. - [[nodiscard]] auto query() const -> QueryRequest; + [[nodiscard]] auto query(std::size_t k) const -> QueryRequest; private: QueryContainer(); diff --git a/src/query.cpp b/src/query.cpp index 1b54efe66..da943bd24 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -9,7 +9,13 @@ namespace pisa { -QueryRequest::QueryRequest(QueryContainer const& data) : m_threshold(data.threshold()) +[[nodiscard]] auto first_equal_to(std::size_t k) +{ + return [k](auto&& pair) { return pair.first == k; }; +} + +QueryRequest::QueryRequest(QueryContainer const& data, std::size_t k) + : m_k(k), m_threshold(data.threshold(k)) { if (auto term_ids = data.term_ids(); term_ids) { m_term_ids = *term_ids; @@ -35,13 +41,13 @@ struct QueryContainerInner { std::optional query_string; std::optional> processed_terms; std::optional> term_ids; - std::optional threshold; + std::vector> thresholds; [[nodiscard]] auto operator==(QueryContainerInner const& other) const noexcept -> bool { return id == other.id && query_string == other.query_string && processed_terms == other.processed_terms && term_ids == other.term_ids - && threshold == other.threshold; + && thresholds == other.thresholds; } }; @@ -115,9 +121,18 @@ auto QueryContainer::term_ids() const noexcept -> std::optionalterm_ids; } -auto QueryContainer::threshold() const noexcept -> std::optional const& +auto QueryContainer::threshold(std::size_t k) const noexcept -> std::optional +{ + auto pos = std::find_if(m_data->thresholds.begin(), m_data->thresholds.end(), first_equal_to(k)); + if (pos == m_data->thresholds.end()) { + return std::nullopt; + } + return std::make_optional(pos->second); +} + +auto QueryContainer::thresholds() const noexcept -> std::vector> const& { - return m_data->threshold; + return m_data->thresholds; } auto QueryContainer::string(std::string raw_query) -> QueryContainer& @@ -143,15 +158,21 @@ auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& return *this; } -auto QueryContainer::threshold(float score) -> QueryContainer& +auto QueryContainer::add_threshold(std::size_t k, float score) -> bool { - m_data->threshold = score; - return *this; + if (auto pos = + std::find_if(m_data->thresholds.begin(), m_data->thresholds.end(), first_equal_to(k)); + pos != m_data->thresholds.end()) { + pos->second = score; + return true; + } + m_data->thresholds.emplace_back(k, score); + return false; } -auto QueryContainer::query() const -> QueryRequest +auto QueryContainer::query(std::size_t k) const -> QueryRequest { - return QueryRequest(*this); + return QueryRequest(*this, k); } template @@ -189,8 +210,25 @@ auto QueryContainer::from_json(std::string_view json_string) -> QueryContainer data.term_ids = std::move(term_ids); at_least_one_required = true; } - if (auto threshold = get(json, "threshold"); threshold) { - data.threshold = threshold; + if (auto thresholds = json.find("thresholds"); thresholds != json.end()) { + auto raise_error = [&]() { + throw std::runtime_error( + fmt::format("Field \"thresholds\" is invalid: {}", thresholds->dump())); + }; + if (not thresholds->is_array()) { + raise_error(); + } + for (auto&& threshold_entry: *thresholds) { + if (not threshold_entry.is_object()) { + raise_error(); + } + auto k = get(threshold_entry, "k"); + auto score = get(threshold_entry, "score"); + if (not k or not score) { + raise_error(); + } + data.thresholds.emplace_back(*k, *score); + } } if (not at_least_one_required) { throw std::invalid_argument(fmt::format( @@ -218,8 +256,15 @@ auto QueryContainer::to_json() const -> std::string if (auto term_ids = m_data->term_ids; term_ids) { json["term_ids"] = *term_ids; } - if (auto threshold = m_data->threshold; threshold) { - json["threshold"] = *threshold; + if (not m_data->thresholds.empty()) { + auto thresholds = nlohmann::json::array(); + for (auto&& [k, score]: m_data->thresholds) { + auto entry = nlohmann::json::object(); + entry["k"] = k; + entry["score"] = score; + thresholds.push_back(std::move(entry)); + } + json["thresholds"] = thresholds; } return json.dump(); } diff --git a/test/test_query.cpp b/test/test_query.cpp index 00b3aea0e..69c7406d8 100644 --- a/test/test_query.cpp +++ b/test/test_query.cpp @@ -107,18 +107,18 @@ TEST_CASE("Parse query container from JSON") REQUIRE(*query.string() == "brooklyn tea house"); REQUIRE_FALSE(query.terms()); REQUIRE_FALSE(query.term_ids()); - REQUIRE_FALSE(query.threshold()); + REQUIRE(query.thresholds().empty()); query = QueryContainer::from_json(R"( { "term_ids": [1, 0, 3], "terms": ["brooklyn", "tea", "house"], - "threshold": 10.8 + "thresholds": [{"k": 10, "score": 10.8}] } )"); REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); - REQUIRE(*query.threshold() == Approx(10.8)); + REQUIRE(*query.threshold(10) == Approx(10.8)); REQUIRE_FALSE(query.id()); REQUIRE_FALSE(query.string()); @@ -133,11 +133,11 @@ TEST_CASE("Serialize query container to JSON") "query": "brooklyn tea house", "terms": ["brooklyn", "tea", "house"], "term_ids": [1, 0, 3], - "threshold": 10.0 + "thresholds": [{"k": 10, "score": 10.0}] } )"); auto serialized = query.to_json(); REQUIRE( serialized - == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"threshold":10.0})"); + == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"thresholds":[{"k":10,"score":10.0}]})"); } From 78cf15c6a788ea043b7b59d5687d533cf2b9736f Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 3 May 2020 16:01:36 +0000 Subject: [PATCH 11/12] Return program with 1 if fails --- tools/filter_queries.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index dbc89469e..99a609d5d 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -43,9 +43,9 @@ int main(int argc, char** argv) app.query_file(), std::move(term_resolver), min_query_len, max_query_len, std::cout); return 0; } catch (pisa::MissingResolverError err) { - spdlog::error("Unresoved queries(without IDs) require term lexicon."); + spdlog::error("Unresoved queries (without IDs) require term lexicon."); } catch (std::runtime_error const& err) { spdlog::error(err.what()); - return 1; } + return 1; } From 83f9c74f6a1b6bdd605c36f70c473434a6973401 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 22 May 2020 00:49:57 +0000 Subject: [PATCH 12/12] Fix merging issue --- src/query/term_resolver.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/query/term_resolver.cpp b/src/query/term_resolver.cpp index 7d00f59ad..3917946ff 100644 --- a/src/query/term_resolver.cpp +++ b/src/query/term_resolver.cpp @@ -39,7 +39,7 @@ StandardTermResolver::StandardTermResolver( return std::nullopt; }; - m_self->transform = pisa::term_processor(stemmer_type); + m_self->transform = pisa::term_processor_builder(stemmer_type)(); if (stopwords_filename) { std::ifstream is(*stopwords_filename);