diff --git a/CMakeLists.txt b/CMakeLists.txt index 86e0f2460..769fee467 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,6 +105,7 @@ target_link_libraries(pisa PUBLIC # TODO(michal): are there any of these we can spdlog fmt::fmt range-v3 + nlohmann_json::nlohmann_json ) target_include_directories(pisa PUBLIC external) diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp new file mode 100644 index 000000000..dbec6ba51 --- /dev/null +++ b/include/pisa/query.hpp @@ -0,0 +1,149 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace pisa { + +struct QueryContainerInner; + +struct ResolvedTerm { + std::uint32_t id; + std::string term; +}; + +using TermProcessorFn = std::function(std::string)>; +using ParseFn = std::function(std::string const&)>; + +class QueryContainer; + +/// QueryRequest is a special container that maintains important invariants, such as sorted term +/// IDs, and also has some additional data, like term weights, etc. +class QueryRequest { + public: + explicit QueryRequest(QueryContainer const& data, std::size_t k); + + [[nodiscard]] auto term_ids() const -> gsl::span; + [[nodiscard]] auto threshold() const -> std::optional; + [[nodiscard]] auto k() const -> std::optional; + + private: + std::size_t m_k; + std::optional m_threshold{}; + std::vector m_term_ids{}; +}; + +class QueryContainer { + public: + QueryContainer(QueryContainer const&); + QueryContainer(QueryContainer&&) noexcept; + QueryContainer& operator=(QueryContainer const&); + QueryContainer& operator=(QueryContainer&&) noexcept; + ~QueryContainer(); + + [[nodiscard]] auto operator==(QueryContainer const& other) const noexcept -> bool; + + /// Constructs a query from a raw string. + [[nodiscard]] static auto raw(std::string query_string) -> QueryContainer; + + /// Constructs a query from a list of terms. + /// + /// \param terms List of terms + /// \param term_processor Function executed for each term before stroring them, + /// e.g., stemming or filtering. This function returns + /// `std::optional`, and all `std::nullopt` values + /// will be filtered out. + [[nodiscard]] static auto + from_terms(std::vector terms, std::optional term_processor) + -> QueryContainer; + + /// Constructs a query from a list of term IDs. + [[nodiscard]] static auto from_term_ids(std::vector term_ids) -> QueryContainer; + + /// Constructs a query from a JSON object. + [[nodiscard]] static auto from_json(std::string_view json_string) -> QueryContainer; + + [[nodiscard]] auto to_json() const -> std::string; + + /// Constructs a query from a colon-separated format: + /// + /// ``` + /// id:raw query string + /// ``` + /// or + /// ``` + /// raw query string + /// ``` + [[nodiscard]] static auto from_colon_format(std::string_view line) -> QueryContainer; + + // Accessors + + [[nodiscard]] auto id() const noexcept -> std::optional const&; + [[nodiscard]] auto string() const noexcept -> std::optional const&; + [[nodiscard]] auto terms() const noexcept -> std::optional> const&; + [[nodiscard]] auto term_ids() const noexcept -> std::optional> const&; + [[nodiscard]] auto threshold(std::size_t k) const noexcept -> std::optional; + [[nodiscard]] auto thresholds() const noexcept + -> std::vector> const&; + + /// Sets the raw string. + [[nodiscard]] auto string(std::string) -> QueryContainer&; + + /// Parses the raw query with the given parser. + /// + /// \throws std::domain_error when raw string is not set + auto parse(ParseFn parse_fn) -> QueryContainer&; + + /// Sets the query score threshold for `k`. + /// + /// If another threshold for the same `k` exists, it will be replaced, + /// and `true` will be returned. Otherwise, `false` will be returned. + auto add_threshold(std::size_t k, float score) -> bool; + + /// Returns a query ready to be used for retrieval. + [[nodiscard]] auto query(std::size_t k) const -> QueryRequest; + + private: + QueryContainer(); + std::unique_ptr m_data; +}; + +enum class Format { Json, Colon }; + +class QueryReader { + public: + /// Open reader from file. + static auto from_file(std::string const& file) -> QueryReader; + /// Open reader from stdin. + static auto from_stdin() -> QueryReader; + + /// Read next query or return `nullopt` if stream has ended. + [[nodiscard]] auto next() -> std::optional; + + /// Execute `fn(q)` for each query `q`. + template + void for_each(Fn&& fn) + { + auto query = next(); + while (query) { + fn(std::move(*query)); + query = next(); + } + } + + private: + explicit QueryReader(std::unique_ptr stream, std::istream& stream_ref); + + std::unique_ptr m_stream; + std::istream& m_stream_ref; + std::string m_line_buf{}; + std::optional m_format{}; +}; + +} // namespace pisa diff --git a/include/pisa/query/query_parser.hpp b/include/pisa/query/query_parser.hpp new file mode 100644 index 000000000..a3500962a --- /dev/null +++ b/include/pisa/query/query_parser.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "query.hpp" +#include "term_resolver.hpp" + +namespace pisa { + +/// Parses a query string to processed terms. +class QueryParser { + public: + explicit QueryParser(TermResolver term_processor); + /// Given a query string, it returns a list of (possibly processed) terms. + /// + /// Possible transformations of terms include lower-casing and stemming. + /// Some terms could be also removed, e.g., because they are on a list of + /// stop words. The exact implementation depends on the term processor + /// passed to the constructor. + auto operator()(std::string const&) -> std::vector; + + private: + TermResolver m_term_resolver; +}; + +} // namespace pisa diff --git a/include/pisa/query/term_resolver.hpp b/include/pisa/query/term_resolver.hpp new file mode 100644 index 000000000..fc9088a48 --- /dev/null +++ b/include/pisa/query/term_resolver.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include + +#include "query.hpp" + +namespace pisa { + +/// Thrown if expected resolver but none found. +struct MissingResolverError { +}; + +using TermResolver = std::function(std::string)>; + +struct StandardTermResolverParams; + +/// Provides a standard implementation of `TermResolver`. +class StandardTermResolver { + public: + StandardTermResolver( + std::string const& term_lexicon_path, + std::optional const& stopwords_filename, + std::optional const& stemmer_type); + StandardTermResolver(StandardTermResolver const&); + StandardTermResolver(StandardTermResolver&&) noexcept; + StandardTermResolver& operator=(StandardTermResolver const&); + StandardTermResolver& operator=(StandardTermResolver&&) noexcept; + ~StandardTermResolver(); + + [[nodiscard]] auto operator()(std::string token) const -> std::optional; + + private: + [[nodiscard]] auto is_stopword(std::uint32_t const term) const -> bool; + + std::unique_ptr m_self; +}; + +/// Reads queries from `query_file`, resolves them with `term_resolver`, filters by +/// query length (number of resolved terms in the query), and prints the selected +/// queries to `out`. +/// +/// \throws MissingResolverError When no resolver passed but queries don't have IDs resolved. +// +void filter_queries( + std::optional const& query_file, + std::optional term_resolver, + std::size_t min_query_len, + std::size_t max_query_len, + std::ostream& out); + +} // namespace pisa diff --git a/src/query.cpp b/src/query.cpp new file mode 100644 index 000000000..da943bd24 --- /dev/null +++ b/src/query.cpp @@ -0,0 +1,323 @@ +#include "query.hpp" + +#include +#include +#include + +#include +#include + +namespace pisa { + +[[nodiscard]] auto first_equal_to(std::size_t k) +{ + return [k](auto&& pair) { return pair.first == k; }; +} + +QueryRequest::QueryRequest(QueryContainer const& data, std::size_t k) + : m_k(k), m_threshold(data.threshold(k)) +{ + if (auto term_ids = data.term_ids(); term_ids) { + m_term_ids = *term_ids; + std::sort(m_term_ids.begin(), m_term_ids.end()); + auto last = std::unique(m_term_ids.begin(), m_term_ids.end()); + m_term_ids.erase(last, m_term_ids.end()); + } + throw std::domain_error("Query not parsed."); +} + +auto QueryRequest::term_ids() const -> gsl::span +{ + return gsl::span(m_term_ids); +} + +auto QueryRequest::threshold() const -> std::optional +{ + return m_threshold; +} + +struct QueryContainerInner { + std::optional id; + std::optional query_string; + std::optional> processed_terms; + std::optional> term_ids; + std::vector> thresholds; + + [[nodiscard]] auto operator==(QueryContainerInner const& other) const noexcept -> bool + { + return id == other.id && query_string == other.query_string + && processed_terms == other.processed_terms && term_ids == other.term_ids + && thresholds == other.thresholds; + } +}; + +QueryContainer::QueryContainer() : m_data(std::make_unique()) {} + +QueryContainer::QueryContainer(QueryContainer const& other) + : m_data(std::make_unique(*other.m_data)) +{} +QueryContainer::QueryContainer(QueryContainer&&) noexcept = default; +QueryContainer& QueryContainer::operator=(QueryContainer const& other) +{ + this->m_data = std::make_unique(*other.m_data); + return *this; +} +QueryContainer& QueryContainer::operator=(QueryContainer&&) noexcept = default; +QueryContainer::~QueryContainer() = default; + +auto QueryContainer::operator==(QueryContainer const& other) const noexcept -> bool +{ + return *m_data == *other.m_data; +} + +auto QueryContainer::raw(std::string query_string) -> QueryContainer +{ + QueryContainer query; + query.m_data->query_string = std::move(query_string); + return query; +} + +auto QueryContainer::from_terms( + std::vector terms, std::optional term_processor) -> QueryContainer +{ + QueryContainer query; + query.m_data->processed_terms = std::vector{}; + auto& processed_terms = *query.m_data->processed_terms; + for (auto&& term: terms) { + if (term_processor) { + auto fn = *term_processor; + if (auto processed = fn(std::move(term)); processed) { + processed_terms.push_back(std::move(*processed)); + } + } else { + processed_terms.push_back(std::move(term)); + } + } + return query; +} + +auto QueryContainer::from_term_ids(std::vector term_ids) -> QueryContainer +{ + QueryContainer query; + query.m_data->term_ids = std::move(term_ids); + return query; +} + +auto QueryContainer::id() const noexcept -> std::optional const& +{ + return m_data->id; +} +auto QueryContainer::string() const noexcept -> std::optional const& +{ + return m_data->query_string; +} +auto QueryContainer::terms() const noexcept -> std::optional> const& +{ + return m_data->processed_terms; +} + +auto QueryContainer::term_ids() const noexcept -> std::optional> const& +{ + return m_data->term_ids; +} + +auto QueryContainer::threshold(std::size_t k) const noexcept -> std::optional +{ + auto pos = std::find_if(m_data->thresholds.begin(), m_data->thresholds.end(), first_equal_to(k)); + if (pos == m_data->thresholds.end()) { + return std::nullopt; + } + return std::make_optional(pos->second); +} + +auto QueryContainer::thresholds() const noexcept -> std::vector> const& +{ + return m_data->thresholds; +} + +auto QueryContainer::string(std::string raw_query) -> QueryContainer& +{ + m_data->query_string = std::move(raw_query); + return *this; +} + +auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& +{ + if (not m_data->query_string) { + throw std::domain_error("Cannot parse, query string not set"); + } + auto parsed_terms = parse_fn(*m_data->query_string); + std::vector processed_terms; + std::vector term_ids; + for (auto&& term: parsed_terms) { + processed_terms.push_back(std::move(term.term)); + term_ids.push_back(term.id); + } + m_data->term_ids = std::move(term_ids); + m_data->processed_terms = std::move(processed_terms); + return *this; +} + +auto QueryContainer::add_threshold(std::size_t k, float score) -> bool +{ + if (auto pos = + std::find_if(m_data->thresholds.begin(), m_data->thresholds.end(), first_equal_to(k)); + pos != m_data->thresholds.end()) { + pos->second = score; + return true; + } + m_data->thresholds.emplace_back(k, score); + return false; +} + +auto QueryContainer::query(std::size_t k) const -> QueryRequest +{ + return QueryRequest(*this, k); +} + +template +[[nodiscard]] auto get(nlohmann::json const& node, std::string_view field) -> std::optional +{ + if (auto pos = node.find(field); pos != node.end()) { + try { + return std::make_optional(pos->get()); + } catch (nlohmann::detail::exception const& err) { + throw std::runtime_error(fmt::format("Requested field {} is of wrong type", field)); + } + } + return std::optional{}; +} + +auto QueryContainer::from_json(std::string_view json_string) -> QueryContainer +{ + try { + auto json = nlohmann::json::parse(json_string); + QueryContainer query; + QueryContainerInner& data = *query.m_data; + bool at_least_one_required = false; + if (auto id = get(json, "id"); id) { + data.id = std::move(id); + } + if (auto raw = get(json, "query"); raw) { + data.query_string = std::move(raw); + at_least_one_required = true; + } + if (auto terms = get>(json, "terms"); terms) { + data.processed_terms = std::move(terms); + at_least_one_required = true; + } + if (auto term_ids = get>(json, "term_ids"); term_ids) { + data.term_ids = std::move(term_ids); + at_least_one_required = true; + } + if (auto thresholds = json.find("thresholds"); thresholds != json.end()) { + auto raise_error = [&]() { + throw std::runtime_error( + fmt::format("Field \"thresholds\" is invalid: {}", thresholds->dump())); + }; + if (not thresholds->is_array()) { + raise_error(); + } + for (auto&& threshold_entry: *thresholds) { + if (not threshold_entry.is_object()) { + raise_error(); + } + auto k = get(threshold_entry, "k"); + auto score = get(threshold_entry, "score"); + if (not k or not score) { + raise_error(); + } + data.thresholds.emplace_back(*k, *score); + } + } + if (not at_least_one_required) { + throw std::invalid_argument(fmt::format( + "JSON must have either raw query, terms, or term IDs: {}", json_string)); + } + return query; + } catch (nlohmann::detail::exception const& err) { + throw std::runtime_error( + fmt::format("Failed to parse JSON: `{}`: {}", json_string, err.what())); + } +} + +auto QueryContainer::to_json() const -> std::string +{ + nlohmann::json json; + if (auto id = m_data->id; id) { + json["id"] = *id; + } + if (auto raw = m_data->query_string; raw) { + json["query"] = *raw; + } + if (auto terms = m_data->processed_terms; terms) { + json["terms"] = *terms; + } + if (auto term_ids = m_data->term_ids; term_ids) { + json["term_ids"] = *term_ids; + } + if (not m_data->thresholds.empty()) { + auto thresholds = nlohmann::json::array(); + for (auto&& [k, score]: m_data->thresholds) { + auto entry = nlohmann::json::object(); + entry["k"] = k; + entry["score"] = score; + thresholds.push_back(std::move(entry)); + } + json["thresholds"] = thresholds; + } + return json.dump(); +} + +auto QueryContainer::from_colon_format(std::string_view line) -> QueryContainer +{ + auto pos = std::find(line.begin(), line.end(), ':'); + QueryContainer query; + QueryContainerInner& data = *query.m_data; + if (pos == line.end()) { + data.query_string = std::string(line); + } else { + data.id = std::string(line.begin(), pos); + data.query_string = std::string(std::next(pos), line.end()); + } + return query; +} + +auto QueryReader::from_file(std::string const& file) -> QueryReader +{ + auto input = std::make_unique(file); + auto& ref = *input; + return QueryReader(std::move(input), ref); +} + +auto QueryReader::from_stdin() -> QueryReader +{ + return QueryReader(nullptr, std::cin); +} + +QueryReader::QueryReader(std::unique_ptr input, std::istream& stream_ref) + : m_stream(std::move(input)), m_stream_ref(stream_ref) +{} + +auto QueryReader::next() -> std::optional +{ + if (std::getline(m_stream_ref, m_line_buf)) { + if (m_format) { + if (*m_format == Format::Json) { + return QueryContainer::from_json(m_line_buf); + } + return QueryContainer::from_colon_format(m_line_buf); + } + try { + auto query = QueryContainer::from_json(m_line_buf); + m_format = Format::Json; + return query; + } catch (std::exception const& err) { + m_format = Format::Colon; + return QueryContainer::from_colon_format(m_line_buf); + } + } + return std::nullopt; +} + +} // namespace pisa diff --git a/src/query/query_parser.cpp b/src/query/query_parser.cpp new file mode 100644 index 000000000..fc64a7891 --- /dev/null +++ b/src/query/query_parser.cpp @@ -0,0 +1,27 @@ +#include + +#include "io.hpp" +#include "payload_vector.hpp" +#include "query.hpp" +#include "query/query_parser.hpp" +#include "query/term_resolver.hpp" +#include "tokenizer.hpp" + +namespace pisa { + +QueryParser::QueryParser(TermResolver term_resolver) : m_term_resolver(std::move(term_resolver)) {} + +auto QueryParser::operator()(std::string const& query) -> std::vector +{ + TermTokenizer tokenizer(query); + std::vector terms; + for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { + auto term = m_term_resolver(*term_iter); + if (term) { + terms.push_back(std::move(*term)); + } + } + return terms; +} + +} // namespace pisa diff --git a/src/query/term_resolver.cpp b/src/query/term_resolver.cpp new file mode 100644 index 000000000..3917946ff --- /dev/null +++ b/src/query/term_resolver.cpp @@ -0,0 +1,100 @@ +#include "query/term_resolver.hpp" +#include "query/query_parser.hpp" +#include "query/term_processor.hpp" + +namespace pisa { + +StandardTermResolver::StandardTermResolver(StandardTermResolver const& other) + : m_self(std::make_unique(*other.m_self)) +{} +StandardTermResolver::StandardTermResolver(StandardTermResolver&&) noexcept = default; +StandardTermResolver& StandardTermResolver::operator=(StandardTermResolver const& other) +{ + m_self = std::make_unique(*other.m_self); + return *this; +} +StandardTermResolver& StandardTermResolver::operator=(StandardTermResolver&&) noexcept = default; +StandardTermResolver::~StandardTermResolver() = default; + +struct StandardTermResolverParams { + std::vector stopwords; + std::function(std::string const&)> to_id; + std::function transform; +}; + +StandardTermResolver::StandardTermResolver( + std::string const& term_lexicon_path, + std::optional const& stopwords_filename, + std::optional const& stemmer_type) + : m_self(std::make_unique()) +{ + auto source = std::make_shared(term_lexicon_path.c_str()); + auto terms = pisa::Payload_Vector<>::from(*source); + + m_self->to_id = [source = std::move(source), terms](auto str) -> std::optional { + auto pos = std::lower_bound(terms.begin(), terms.end(), std::string_view(str)); + if (*pos == std::string_view(str)) { + return std::distance(terms.begin(), pos); + } + return std::nullopt; + }; + + m_self->transform = pisa::term_processor_builder(stemmer_type)(); + + if (stopwords_filename) { + std::ifstream is(*stopwords_filename); + pisa::io::for_each_line(is, [&](auto&& word) { + if (auto term_id = m_self->to_id(std::move(word)); term_id.has_value()) { + m_self->stopwords.push_back(*term_id); + } + }); + std::sort(m_self->stopwords.begin(), m_self->stopwords.end()); + } +} + +auto StandardTermResolver::operator()(std::string token) const -> std::optional +{ + token = m_self->transform(token); + auto id = m_self->to_id(token); + if (not id) { + return std::nullopt; + } + if (is_stopword(*id)) { + return std::nullopt; + } + return pisa::ResolvedTerm{*id, token}; +} + +auto StandardTermResolver::is_stopword(std::uint32_t const term) const -> bool +{ + auto pos = std::lower_bound(m_self->stopwords.begin(), m_self->stopwords.end(), term); + return pos != m_self->stopwords.end() && *pos == term; +} + +void filter_queries( + std::optional const& query_file, + std::optional term_resolver, + std::size_t min_query_len, + std::size_t max_query_len, + std::ostream& out) +{ + auto reader = [&] { + if (query_file) { + return QueryReader::from_file(*query_file); + } + return QueryReader::from_stdin(); + }(); + reader.for_each([&](auto query) { + if (not query.term_ids()) { + if (not term_resolver) { + throw MissingResolverError{}; + } + query.parse(QueryParser(*term_resolver)); + } + if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { + out << query.to_json() << '\n'; + } + }); +} + +} // namespace pisa diff --git a/test/cli/run.sh b/test/cli/run.sh new file mode 100755 index 000000000..51486a077 --- /dev/null +++ b/test/cli/run.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +DIR=$(dirname "$0") +$DIR/setup.sh +bats $DIR/test_filter_queries.sh diff --git a/test/cli/setup.sh b/test/cli/setup.sh new file mode 100755 index 000000000..c55753f1a --- /dev/null +++ b/test/cli/setup.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# This script should be executed within the build directory that is directly +# in the project directory, e.g., /path/to/pisa/build + +PISA_BIN="./bin" +export PATH="$PISA_BIN:$PATH" + +cat "../test/test_data/clueweb1k.plaintext" | parse_collection \ + --stemmer porter2 \ + --output "./fwd" \ + --format plaintext + +invert --input "./fwd" --output "./inv" + +compress_inverted_index --check \ + --encoding block_simdbp \ + --collection "./inv" \ + --output "./simdbp" + +create_wand_data \ + --scorer bm25 \ + --collection "./inv" \ + --output "./bm25.bmw" \ + --block-size 32 diff --git a/test/cli/test_filter_queries.sh b/test/cli/test_filter_queries.sh new file mode 100644 index 000000000..2fa486074 --- /dev/null +++ b/test/cli/test_filter_queries.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bats + +PISA_BIN="bin" +export PATH="$PISA_BIN:$PATH" + +function write_lines { + file=$1 + rm -f "$file" + shift + for line in "$@" + do + echo "$line" >> "$file" + done +} + + +function setup { + write_lines "$BATS_TMPDIR/queries.txt" "brooklyn tea house" "labradoodle" 'Tell your dog I said "hi"' + write_lines "$BATS_TMPDIR/stopwords.txt" "i" "your" +} + +@test "Filter from plain" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stemmer porter2 --terms ./fwd.termlex) + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with minimum length" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --min 4 --stemmer porter2 --terms ./fwd.termlex) + expected='{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with maximum length" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --max 4 --stemmer porter2 --terms ./fwd.termlex) + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with stopwords" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stopwords "$BATS_TMPDIR/stopwords.txt" --stemmer porter2 --terms ./fwd.termlex) + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,10396,26032,15114],"terms":["tell","dog","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter without stemmer" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --terms ./fwd.termlex) + expected='{"query":"brooklyn tea house","term_ids":[6535,29194],"terms":["brooklyn","tea"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Accept JSON" { + echo '{"query":"brooklyn tea house"}' > "$BATS_TMPDIR/queries.json" + result=$(cat $BATS_TMPDIR/queries.json | filter-queries --terms ./fwd.termlex) + expected='{"query":"brooklyn tea house","term_ids":[6535,29194],"terms":["brooklyn","tea"]}' + [[ "$result" = "$expected" ]] +} + +@test "Accept JSON without --terms if already parsed" { + echo '{"term_ids":[6535,29194]}' > "$BATS_TMPDIR/queries.json" + result=$(cat $BATS_TMPDIR/queries.json | filter-queries) + expected='{"term_ids":[6535,29194]}' + [[ "$result" = "$expected" ]] +} + +@test "Fail when no --terms and not parsed" { + echo '{"query":"brooklyn tea house"}' > "$BATS_TMPDIR/queries.json" + run filter-queries < $BATS_TMPDIR/queries.json + [[ "$status" -eq 1 ]] + [[ "$output" = *"[error] Unresoved queries (without IDs) require term lexicon." ]] +} diff --git a/test/test_query.cpp b/test/test_query.cpp new file mode 100644 index 000000000..69c7406d8 --- /dev/null +++ b/test/test_query.cpp @@ -0,0 +1,143 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "query.hpp" + +using pisa::QueryContainer; + +TEST_CASE("Construct from raw string") +{ + auto raw_query = "brooklyn tea house"; + auto query = QueryContainer::raw(raw_query); + REQUIRE(*query.string() == raw_query); +} + +TEST_CASE("Construct from terms") +{ + std::vector terms{"brooklyn", "tea", "house"}; + auto query = QueryContainer::from_terms(terms, std::nullopt); + REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); +} + +TEST_CASE("Construct from terms with processor") +{ + std::vector terms{"brooklyn", "tea", "house"}; + auto proc = [](std::string term) -> std::optional { + if (term.size() > 3) { + return term.substr(0, 4); + } + return std::nullopt; + }; + auto query = QueryContainer::from_terms(terms, proc); + REQUIRE(*query.terms() == std::vector{"broo", "hous"}); +} + +TEST_CASE("Construct from term IDs") +{ + std::vector term_ids{1, 0, 3}; + auto query = QueryContainer::from_term_ids(term_ids); + REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); +} + +TEST_CASE("Parse query") +{ + auto raw_query = "brooklyn tea house brooklyn"; + auto query = QueryContainer::raw(raw_query); + std::vector lexicon{"house", "brooklyn"}; + auto term_proc = [](std::string term) -> std::optional { return term; }; + query.parse([&](auto&& q) { + std::istringstream is(q); + std::string term; + std::vector parsed_terms; + while (is >> term) { + if (auto t = term_proc(term); t) { + if (auto pos = std::find(lexicon.begin(), lexicon.end(), *t); pos != lexicon.end()) { + auto id = static_cast(std::distance(lexicon.begin(), pos)); + parsed_terms.push_back(pisa::ResolvedTerm{id, *t}); + } + } + } + return parsed_terms; + }); + REQUIRE(*query.term_ids() == std::vector{1, 0, 1}); +} + +TEST_CASE("Parsing throws without raw query") +{ + std::vector term_ids{1, 0, 3}; + auto query = QueryContainer::from_term_ids(term_ids); + REQUIRE_THROWS_AS( + query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); +} + +TEST_CASE("Parse query container from colon-delimited format") +{ + auto query = QueryContainer::from_colon_format(""); + REQUIRE(query.string()->empty()); + REQUIRE_FALSE(query.id()); + + query = QueryContainer::from_colon_format("brooklyn tea house"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE_FALSE(query.id()); + + query = QueryContainer::from_colon_format("BTH:brooklyn tea house"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE(*query.id() == "BTH"); + + query = QueryContainer::from_colon_format("BTH:"); + REQUIRE(query.string()->empty()); + REQUIRE(*query.id() == "BTH"); +} + +TEST_CASE("Parse query container from JSON") +{ + REQUIRE_THROWS_AS(QueryContainer::from_json(""), std::runtime_error); + REQUIRE_THROWS_AS(QueryContainer::from_json(R"({"id":"ID"})"), std::invalid_argument); + + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house" + } + )"); + REQUIRE(*query.id() == "ID"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE_FALSE(query.terms()); + REQUIRE_FALSE(query.term_ids()); + REQUIRE(query.thresholds().empty()); + + query = QueryContainer::from_json(R"( + { + "term_ids": [1, 0, 3], + "terms": ["brooklyn", "tea", "house"], + "thresholds": [{"k": 10, "score": 10.8}] + } + )"); + REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); + REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); + REQUIRE(*query.threshold(10) == Approx(10.8)); + REQUIRE_FALSE(query.id()); + REQUIRE_FALSE(query.string()); + + REQUIRE_THROWS_AS(QueryContainer::from_json(R"({"terms":[1, 2]})"), std::runtime_error); +} + +TEST_CASE("Serialize query container to JSON") +{ + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house", + "terms": ["brooklyn", "tea", "house"], + "term_ids": [1, 0, 3], + "thresholds": [{"k": 10, "score": 10.0}] + } + )"); + auto serialized = query.to_json(); + REQUIRE( + serialized + == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"thresholds":[{"k":10,"score":10.0}]})"); +} diff --git a/test/test_query_parser.cpp b/test/test_query_parser.cpp new file mode 100644 index 000000000..95f028124 --- /dev/null +++ b/test/test_query_parser.cpp @@ -0,0 +1,28 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "query/query_parser.hpp" + +using pisa::QueryContainer; +using pisa::QueryParser; + +TEST_CASE("Parse with lower-case processor and stop word") +{ + std::uint32_t init_id = 0; + auto term_proc = [id = init_id](auto&& term) mutable { + std::transform( + term.begin(), term.end(), term.begin(), [](unsigned char c) { return std::tolower(c); }); + if (term == "house") { + return std::optional{}; + } + return std::optional{pisa::ResolvedTerm{id++, term}}; + }; + QueryParser parser(term_proc); + auto terms = parser("Brooklyn tea house"); + REQUIRE(terms.size() == 2); + REQUIRE(terms[0].term == "brooklyn"); + REQUIRE(terms[1].term == "tea"); +} diff --git a/test/test_term_resolver.cpp b/test/test_term_resolver.cpp new file mode 100644 index 000000000..913060b85 --- /dev/null +++ b/test/test_term_resolver.cpp @@ -0,0 +1,69 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "io.hpp" +#include "query/term_resolver.hpp" +#include "temporary_directory.hpp" + +using pisa::QueryContainer; +using pisa::StandardTermResolver; + +TEST_CASE("Filter queries") +{ + std::uint32_t id = 0; + auto term_resolver = [&id](auto&& term) mutable { + return std::optional{pisa::ResolvedTerm{id++, term}}; + }; + Temporary_Directory tmp; + auto input = (tmp.path() / "input.txt"); + { + std::ofstream os(input.c_str()); + os << "a b c d\n"; + os << "e\n"; + os << "f g h i j\n"; + os << "k l m\n"; + os << "n o\n"; + } + + SECTION("Between 2 and 4") + { + std::ostringstream os; + pisa::filter_queries( + std::make_optional(input.string()), std::make_optional(term_resolver), 2, 4, os); + std::vector queries; + std::istringstream is(os.str()); + pisa::io::for_each_line( + is, [&queries](auto&& line) { queries.push_back(QueryContainer::from_json(line)); }); + REQUIRE(queries.size() == 3); + REQUIRE(*queries[0].terms() == std::vector{"a", "b", "c", "d"}); + REQUIRE(*queries[0].term_ids() == std::vector{0, 1, 2, 3}); + REQUIRE(*queries[1].terms() == std::vector{"k", "l", "m"}); + REQUIRE(*queries[1].term_ids() == std::vector{10, 11, 12}); + REQUIRE(*queries[2].terms() == std::vector{"n", "o"}); + REQUIRE(*queries[2].term_ids() == std::vector{13, 14}); + + SECTION("Don't fail if no resolver but IDs already resolved") + { + auto json_input = (tmp.path() / "input.json"); + { + std::ofstream json_out(json_input.c_str()); + for (auto&& query: queries) { + json_out << query.to_json() << '\n'; + } + } + std::ostringstream output; + pisa::filter_queries(std::make_optional(json_input.string()), std::nullopt, 2, 4, output); + REQUIRE(output.str() == os.str()); + } + } + + SECTION("Fail without IDs and resolver") + { + REQUIRE_THROWS_AS( + pisa::filter_queries(std::make_optional(input.string()), std::nullopt, 2, 4, std::cerr), + pisa::MissingResolverError); + } +} diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index c3a1901c4..260a4383c 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -122,6 +122,12 @@ target_link_libraries(reorder-docids CLI11 ) +add_executable(filter-queries filter_queries.cpp) +target_link_libraries(filter-queries + pisa + CLI11 +) + add_executable(kth_threshold kth_threshold.cpp) target_link_libraries(kth_threshold pisa diff --git a/tools/app.hpp b/tools/app.hpp index a94844887..0fe17b3da 100644 --- a/tools/app.hpp +++ b/tools/app.hpp @@ -114,6 +114,21 @@ namespace arg { return q; } + [[nodiscard]] auto term_lexicon() const -> std::optional const& + { + return m_term_lexicon; + } + + [[nodiscard]] auto stemmer() const -> std::optional const& + { + return m_stemmer; + } + + [[nodiscard]] auto stop_words() const -> std::optional const& + { + return m_stop_words; + } + [[nodiscard]] auto k() const -> int { return m_k; } private: diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp new file mode 100644 index 000000000..99a609d5d --- /dev/null +++ b/tools/filter_queries.cpp @@ -0,0 +1,51 @@ +#include + +#include +#include +#include + +#include "app.hpp" +#include "query.hpp" +#include "query/query_parser.hpp" +#include "query/term_resolver.hpp" +#include "tokenizer.hpp" + +namespace arg = pisa::arg; + +using pisa::filter_queries; +using pisa::QueryContainer; +using pisa::QueryParser; +using pisa::QueryReader; +using pisa::StandardTermResolver; +using pisa::TermResolver; +using pisa::io::for_each_line; + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + std::size_t min_query_len = 1; + std::size_t max_query_len = std::numeric_limits::max(); + + pisa::App> app("Filters queries by their length"); + app.add_option("--min", min_query_len, "Minimum query legth to consider"); + app.add_option("--max", max_query_len, "Maximum query legth to consider"); + CLI11_PARSE(app, argc, argv); + + std::optional term_resolver{}; + if (app.term_lexicon()) { + term_resolver = StandardTermResolver(*app.term_lexicon(), app.stop_words(), app.stemmer()); + } + + try { + filter_queries( + app.query_file(), std::move(term_resolver), min_query_len, max_query_len, std::cout); + return 0; + } catch (pisa::MissingResolverError err) { + spdlog::error("Unresoved queries (without IDs) require term lexicon."); + } catch (std::runtime_error const& err) { + spdlog::error(err.what()); + } + return 1; +}