diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cd27f3c30..53f6bd98a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,9 +61,12 @@ jobs: sudo apt-get remove libstdc++-10-dev libgcc-10-dev cpp-10 sudo apt-get remove libstdc++-11-dev libgcc-11-dev cpp-11 fi + sudo add-apt-repository -y 'deb http://mirror.plusserver.com/ubuntu/ubuntu/ bionic-updates main restricted universe multiverse' sudo apt-get update sudo apt-get install -y libtool m4 autoconf - if [ "${{ matrix.compiler }}" = "gcc" ]; then + if [ "${cc}" = "gcc" ]; then + sudo apt-get install -y g++-7/bionic-updates + elif [ "${{ matrix.compiler }}" = "gcc" ]; then sudo apt-get install -y "${cxx}" else echo "TOOLCHAIN=-DCMAKE_TOOLCHAIN_FILE=clang.cmake" >> $GITHUB_ENV diff --git a/benchmarks/index_perftest.cpp b/benchmarks/index_perftest.cpp index 648b93c4e..474b550a8 100644 --- a/benchmarks/index_perftest.cpp +++ b/benchmarks/index_perftest.cpp @@ -9,8 +9,8 @@ using pisa::do_not_optimize_away; using pisa::get_time_usecs; -template -void perftest(IndexType const& index, std::string const& type) +template +void perftest(IndexType&& index, std::string const& type) { std::string freqs_log = with_freqs ? "+freq()" : ""; { @@ -103,18 +103,6 @@ void perftest(IndexType const& index, std::string const& type) } } -template -void perftest(const char* index_filename, std::string const& type) -{ - spdlog::info("Loading index from {}", index_filename); - IndexType index; - mio::mmap_source m(index_filename); - pisa::mapper::map(index, m, pisa::mapper::map_flags::warmup); - - perftest(index, type); - perftest(index, type); -} - int main(int argc, const char** argv) { using namespace pisa; @@ -127,17 +115,13 @@ int main(int argc, const char** argv) std::string type = argv[1]; const char* index_filename = argv[2]; - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (type == BOOST_PP_STRINGIZE(T)) \ - { \ - perftest(index_filename, type); \ - /**/ - - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - } else { - spdlog::error("Unknown type {}", type); + try { + IndexType::resolve(type).load_and_execute(index_filename, [&](auto&& index) { + perftest(index, type); + perftest(index, type); + }); + } catch (std::exception const& err) { + spdlog::error("{}", err.what()); + return 1; } } diff --git a/include/pisa/compress.hpp b/include/pisa/compress.hpp index 9aaa56bf6..d9db50bfe 100644 --- a/include/pisa/compress.hpp +++ b/include/pisa/compress.hpp @@ -246,27 +246,17 @@ void compress( binary_freq_collection input(input_basename.c_str()); global_parameters params; - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (index_encoding == BOOST_PP_STRINGIZE(T)) \ - { \ - compress_index>( \ - input, \ - params, \ - output_filename, \ - check, \ - index_encoding, \ - wand_data_filename, \ - scorer_params, \ - quantize); \ - /**/ - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - } else { - spdlog::error("Unknown type {}", index_encoding); - std::abort(); - } + IndexType::resolve(index_encoding).execute([&](auto type_marker) { + compress_index>( + input, + params, + output_filename, + check, + index_encoding, + wand_data_filename, + scorer_params, + quantize); + }); } } // namespace pisa diff --git a/include/pisa/index_types.hpp b/include/pisa/index_types.hpp index e199cd580..aed6c2254 100644 --- a/include/pisa/index_types.hpp +++ b/include/pisa/index_types.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include "boost/preprocessor/cat.hpp" #include "boost/preprocessor/seq/for_each.hpp" #include "boost/preprocessor/stringize.hpp" @@ -22,13 +24,11 @@ #include "sequence/uniform_partitioned_sequence.hpp" namespace pisa { -using ef_index = freq_index>; +using ef_index = freq_index>; using single_index = freq_index>; - using pefuniform_index = freq_index, positive_sequence>>; - using pefopt_index = freq_index, positive_sequence>>; @@ -43,12 +43,269 @@ using block_simple8b_index = block_freq_index; using block_simple16_index = block_freq_index; using block_simdbp_index = block_freq_index; -} // namespace pisa +using profiling_block_optpfor_index = block_freq_index; +using profiling_block_varintg8iu_index = block_freq_index; +using profiling_block_streamvbyte_index = block_freq_index; +using profiling_block_maskedvbyte_index = block_freq_index; +using profiling_block_varintgb_index = block_freq_index; +using profiling_block_interpolative_index = block_freq_index; +using profiling_block_qmx_index = block_freq_index; +using profiling_block_simple8b_index = block_freq_index; +using profiling_block_simple16_index = block_freq_index; +using profiling_block_simdbp_index = block_freq_index; + +/// Exception thrown when an invalid index encoding is requested. See `IndexType::resolve`. +class InvalidEncoding: std::exception { + public: + explicit InvalidEncoding(std::string_view encoding) + : m_message(fmt::format("Invalid encoding: {}", encoding)) + {} + [[nodiscard]] auto what() const noexcept -> char const* override { return m_message.c_str(); } + + private: + std::string m_message; +}; + +namespace detail { + + /// This object holds an index type; this can be passed to a visitor without constructing the + /// index itself. + template + struct IndexTypeMarker { + using type = Index; + }; + + /// Variant type listing all supported index types. + /// + /// NOTE: To support a new index type in the tools, it must be added here. + using IndexType = std::variant< + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker, + IndexTypeMarker>; + +} // namespace detail + +/// Number of supported index types. +constexpr std::size_t NUM_INDEX_TYPES = variant_size_v; + +namespace detail { + + /// Returns the string representation of the index type `I`. + template + [[nodiscard]] constexpr auto index_name() noexcept -> std::string_view + { + if constexpr (std::is_same_v) { + return "ef"; + } else if constexpr (std::is_same_v) { + return "single"; + } else if constexpr (std::is_same_v) { + return "pefuniform"; + } else if constexpr (std::is_same_v) { + return "pefopt"; + } else if constexpr (std::is_same_v) { + return "block_optpfor"; + } else if constexpr (std::is_same_v) { + return "block_varintg8iu"; + } else if constexpr (std::is_same_v) { + return "block_streamvbyte"; + } else if constexpr (std::is_same_v) { + return "block_maskedvbyte"; + } else if constexpr (std::is_same_v) { + return "block_interpolative"; + } else if constexpr (std::is_same_v) { + return "block_qmx"; + } else if constexpr (std::is_same_v) { + return "block_varintgb"; + } else if constexpr (std::is_same_v) { + return "block_simple8b"; + } else if constexpr (std::is_same_v) { + return "block_simple16"; + } else if constexpr (std::is_same_v) { + return "block_simdbp"; + } else if constexpr (std::is_same_v) { + return "block_optpfor"; + } else if constexpr (std::is_same_v) { + return "block_varintg8iu"; + } else if constexpr (std::is_same_v) { + return "block_streamvbyte"; + } else if constexpr (std::is_same_v) { + return "block_maskedvbyte"; + } else if constexpr (std::is_same_v) { + return "block_interpolative"; + } else if constexpr (std::is_same_v) { + return "block_qmx"; + } else if constexpr (std::is_same_v) { + return "block_varintgb"; + } else if constexpr (std::is_same_v) { + return "block_simple8b"; + } else if constexpr (std::is_same_v) { + return "block_simple16"; + } else if constexpr (std::is_same_v) { + return "block_simdbp"; + } + } + + template + struct add_profiling { + using type = Index; + }; -#define PISA_INDEX_TYPES \ - (ef)(single)(pefuniform)(pefopt)(block_optpfor)(block_varintg8iu)(block_streamvbyte)( \ - block_maskedvbyte)(block_interpolative)(block_qmx)(block_varintgb)(block_simple8b)( \ - block_simple16)(block_simdbp) -#define PISA_BLOCK_INDEX_TYPES \ - (block_optpfor)(block_varintg8iu)(block_streamvbyte)(block_maskedvbyte)(block_interpolative)( \ - block_qmx)(block_varintgb)(block_simple8b)(block_simple16)(block_simdbp) + template + struct add_profiling> { + using type = block_freq_index; + }; + + /// Attempts to match the given encoding name to the index type at position `N` in the index + /// type variant. + /// + /// The function passed in the argument assigns the resolved type object if it is matched with + /// the encoding. If the match succeeds, this function returns `true`, otherwise, it returns + /// `false`. + template + auto resolve_index_n(std::string_view encoding, Fn&& fn) -> bool + { + using T = typename std::variant_alternative_t::type; + + if (encoding == detail::index_name()) { + if constexpr (Profile) { + using P = typename add_profiling::type; + fn(IndexType(IndexTypeMarker

{})); + } else { + fn(IndexType(IndexTypeMarker{})); + } + return true; + } + return false; + } + + /// Returns the type object of the index with the given encoding, or throws `InvalidEncoding` if + /// the given encoding name does not match any index type. + template + auto resolve_index_type(std::string_view encoding, std::integer_sequence) + -> IndexType + { + IndexType index_type{}; + auto update = [&](IndexType i) { index_type = i; }; + bool loaded = (resolve_index_n(encoding, update) || ...); + if (!loaded) { + throw InvalidEncoding(encoding); + } + return index_type; + } + + /// Returns the type object of the index with the given encoding, or throws `InvalidEncoding` if + /// the given encoding name does not match any index type. + template + [[nodiscard]] auto resolve_index_type(std::string_view encoding) -> IndexType + { + std::string full_index_type = fmt::format("{}_index", encoding); + constexpr auto num_types = std::variant_size_v; + return detail::resolve_index_type(encoding, std::make_index_sequence{}); + } + + template + constexpr void push_encoding(std::array& encodings) + { + encodings[I] = + detail::index_name::type>(); + } + + template + [[nodiscard]] constexpr auto encodings(std::integer_sequence) + -> std::array + { + std::array encodings; + (push_encoding(encodings), ...); + return encodings; + } + +}; // namespace detail + +/// Returns an array of the names of all supported index types. +[[nodiscard]] inline constexpr auto encodings() -> std::array +{ + return detail::encodings(std::make_index_sequence>{}); +} + +/// Represents an index type. +/// +/// This type's objects are used execute monomorphized code on an index of a particular type. +/// See `execute()` and `load_and_execute()` for more details. +class IndexType { + detail::IndexType m_type_marker; + + explicit IndexType(detail::IndexType type_marker) : m_type_marker(type_marker) {} + + public: + /// Resolves the index type beased on its string representation. + static auto resolve(std::string_view encoding) -> IndexType + { + return IndexType(detail::resolve_index_type(encoding)); + } + + /// Same as `resolve` but use a profiling index. + static auto resolve_profiling(std::string_view encoding) -> IndexType + { + return IndexType(detail::resolve_index_type(encoding)); + } + + /// Executes a templated `fn`, passing to it an index type marker that can be used to access the + /// index type. + /// + /// # Example + /// + /// ``` + /// IndexType::resolve("block_simdbp").execute([](auto marker) { + /// using index_type = typename decltype(marker)::type; + /// }) + /// ``` + template + void execute(Fn fn) + { + std::visit(fn, m_type_marker); + } + + /// Loads an index from the given path, and executes a templated `fn`, passing to it the loaded + /// index. + /// + /// # Example + /// + /// ``` + /// IndexType::resolve("block_simdbp").load_and_execute(index_path, [](auto&& index) { + /// // use `index`, which ehere is of type `block_simdbp_index` + /// }) + /// ``` + template + void load_and_execute(std::string const& index_path, Fn fn) + { + std::visit( + [&](auto marker) { + using I = typename decltype(marker)::type; + fn(I(MemorySource::mapped_file(index_path))); + }, + m_type_marker); + } +}; + +} // namespace pisa diff --git a/include/pisa/mappable/mappable_vector.hpp b/include/pisa/mappable/mappable_vector.hpp index 5b4c6a5a5..4fff7b7e3 100644 --- a/include/pisa/mappable/mappable_vector.hpp +++ b/include/pisa/mappable/mappable_vector.hpp @@ -21,8 +21,11 @@ namespace pisa { namespace mapper { using deleter_t = boost::function; - template // T must be a POD + template class mappable_vector { + static_assert(std::is_standard_layout_v, "T must be a POD"); + static_assert(std::is_trivial_v, "T must be a POD"); + public: using value_type = T; using iterator = const T*; @@ -30,7 +33,7 @@ namespace pisa { namespace mapper { mappable_vector() : m_data(0), m_size(0), m_deleter() {} mappable_vector(mappable_vector const&) = delete; - mappable_vector(mappable_vector&&) = delete; + mappable_vector(mappable_vector&&) = default; mappable_vector& operator=(mappable_vector const&) = delete; mappable_vector& operator=(mappable_vector&&) = delete; diff --git a/include/pisa/wand_data.hpp b/include/pisa/wand_data.hpp index 8af044f2a..d8a5bb9d4 100644 --- a/include/pisa/wand_data.hpp +++ b/include/pisa/wand_data.hpp @@ -30,6 +30,12 @@ class wand_data { using wand_data_enumerator = typename block_wand_type::enumerator; wand_data() = default; + wand_data(wand_data&&) = default; + wand_data(wand_data const&) = delete; + wand_data& operator=(wand_data&&) noexcept = delete; + wand_data& operator=(wand_data const&) = delete; + ~wand_data() = default; + explicit wand_data(MemorySource source) : m_source(std::move(source)) { mapper::map(*this, m_source.data(), mapper::map_flags::warmup); diff --git a/include/pisa/wand_data_raw.hpp b/include/pisa/wand_data_raw.hpp index 36db5f720..d9210aee7 100644 --- a/include/pisa/wand_data_raw.hpp +++ b/include/pisa/wand_data_raw.hpp @@ -16,6 +16,11 @@ namespace pisa { class wand_data_raw { public: wand_data_raw() = default; + wand_data_raw(wand_data_raw&&) = default; + wand_data_raw(wand_data_raw const&) = delete; + wand_data_raw& operator=(wand_data_raw&&) noexcept = delete; + wand_data_raw& operator=(wand_data_raw const&) = delete; + ~wand_data_raw() = default; class builder { public: diff --git a/test/test_index_types.cpp b/test/test_index_types.cpp new file mode 100644 index 000000000..329a3ad42 --- /dev/null +++ b/test/test_index_types.cpp @@ -0,0 +1,22 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include "pisa/index_types.hpp" + +TEST_CASE("freq_index", "[index][unit]") +{ + REQUIRE(pisa::detail::index_name() == "ef"); + REQUIRE(pisa::detail::index_name() == "single"); + REQUIRE(pisa::detail::index_name() == "pefuniform"); + REQUIRE(pisa::detail::index_name() == "pefopt"); + REQUIRE(pisa::detail::index_name() == "block_optpfor"); + REQUIRE(pisa::detail::index_name() == "block_varintg8iu"); + REQUIRE(pisa::detail::index_name() == "block_streamvbyte"); + REQUIRE(pisa::detail::index_name() == "block_maskedvbyte"); + REQUIRE(pisa::detail::index_name() == "block_interpolative"); + REQUIRE(pisa::detail::index_name() == "block_qmx"); + REQUIRE(pisa::detail::index_name() == "block_varintgb"); + REQUIRE(pisa::detail::index_name() == "block_simple8b"); + REQUIRE(pisa::detail::index_name() == "block_simple16"); + REQUIRE(pisa::detail::index_name() == "block_simdbp"); +} diff --git a/tools/app.hpp b/tools/app.hpp index fb559fd6d..81a61e18a 100644 --- a/tools/app.hpp +++ b/tools/app.hpp @@ -11,6 +11,7 @@ #include #include +#include "index_types.hpp" #include "io.hpp" #include "query/queries.hpp" #include "scorer/scorer.hpp" @@ -26,6 +27,12 @@ namespace arg { explicit Encoding(CLI::App* app) { app->add_option("-e,--encoding", m_encoding, "Index encoding")->required(); + app->add_flag_callback("--list-encodings", []() { + for (auto&& encoding: encodings()) { + std::cout << encoding << '\n'; + } + std::exit(0); + }); } [[nodiscard]] auto index_encoding() const -> std::string const& { return m_encoding; } @@ -141,6 +148,21 @@ namespace arg { explicit Algorithm(CLI::App* app) { app->add_option("-a,--algorithm", m_algorithm, "Query processing algorithm")->required(); + app->add_flag_callback("--list-algorithms", []() { + std::cout << "and_query\n"; + std::cout << "block_max_maxscore_query\n"; + std::cout << "block_max_ranked_and_query\n"; + std::cout << "block_max_wand_query\n"; + std::cout << "maxscore_query\n"; + std::cout << "or_query\n"; + std::cout << "range_query\n"; + std::cout << "range_taat_query\n"; + std::cout << "ranked_and_query\n"; + std::cout << "ranked_or_query\n"; + std::cout << "ranked_or_taat_query\n"; + std::cout << "wand_query\n"; + std::exit(0); + }); } [[nodiscard]] auto algorithm() const -> std::string const& { return m_algorithm; } @@ -166,6 +188,19 @@ namespace arg { app->add_option("--bm25-b", args.m_params.bm25_b, "BM25 b parameter.")->needs(scorer); app->add_option("--pl2-c", args.m_params.pl2_c, "PL2 c parameter.")->needs(scorer); app->add_option("--qld-mu", args.m_params.qld_mu, "QLD mu parameter.")->needs(scorer); + app->add_flag_callback("--list-scorers", []() { + std::cout << "bm25" + << "\tOkapi BM25\n"; + std::cout << "qld" + << "\tQuery Likelihood with Dirichlet Smoothing\n"; + std::cout << "pl2" + << "\tPL2 probabilistic model\n"; + std::cout << "dph" + << "\tDPH model\n"; + std::cout << "quantized" + << "\tQuantized scores are read directly from the index\n"; + std::exit(0); + }); return scorer; } diff --git a/tools/compute_intersection.cpp b/tools/compute_intersection.cpp index 14ab80307..fff58d10d 100644 --- a/tools/compute_intersection.cpp +++ b/tools/compute_intersection.cpp @@ -21,18 +21,14 @@ using namespace pisa; using pisa::intersection::IntersectionType; using pisa::intersection::Mask; -template +template void intersect( - std::string const& index_filename, + Index&& index, std::optional const& wand_data_filename, QueryRange&& queries, IntersectionType intersection_type, std::optional max_term_count = std::nullopt) { - IndexType index; - mio::mmap_source m(index_filename.c_str()); - mapper::map(index, m); - WandType wdata; mio::mmap_source md; @@ -117,24 +113,13 @@ int main(int argc, const char** argv) IntersectionType intersection_type = combinations ? IntersectionType::Combinations : IntersectionType::Query; - /**/ - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ - { \ - intersect( \ - app.index_filename(), \ - app.wand_data_path(), \ - filtered_queries, \ - intersection_type, \ - max_term_count); \ - /**/ - - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - - } else { - spdlog::error("Unknown type {}", app.index_encoding()); + try { + IndexType::resolve(app.index_encoding()).load_and_execute(app.index_filename(), [&](auto&& index) { + intersect( + index, app.wand_data_path(), filtered_queries, intersection_type, max_term_count); + }); + } catch (std::exception const& err) { + spdlog::error("{}", err.what()); + return 1; } } diff --git a/tools/count_postings.cpp b/tools/count_postings.cpp index 7501ca7c0..da4e5df22 100644 --- a/tools/count_postings.cpp +++ b/tools/count_postings.cpp @@ -14,15 +14,14 @@ using namespace pisa; -template +template void extract( - std::string const& index_filename, + I const& index, std::vector const& queries, std::string const& separator, bool sum, bool print_qid) { - Index index(MemorySource::mapped_file(index_filename)); auto body = [&] { if (sum) { return std::function([&](auto const& query) { @@ -66,20 +65,13 @@ int main(int argc, char** argv) "printed, separated by the separator defined with --sep"); CLI11_PARSE(app, argc, argv); - /**/ - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ - { \ - extract( \ - app.index_filename(), app.queries(), app.separator(), sum, app.print_query_id()); \ - /**/ - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - - } else { - spdlog::error("Unknown type {}", app.index_encoding()); + try { + IndexType::resolve(app.index_encoding()) + .load_and_execute(app.index_filename(), [&](auto const& index) { + extract(index, app.queries(), app.separator(), sum, app.print_query_id()); + }); + } catch (InvalidEncoding const& err) { + spdlog::error("{}", err.what()); } return 0; diff --git a/tools/evaluate_queries.cpp b/tools/evaluate_queries.cpp index 4aac020cb..f03effd84 100644 --- a/tools/evaluate_queries.cpp +++ b/tools/evaluate_queries.cpp @@ -30,10 +30,10 @@ using namespace pisa; using ranges::views::enumerate; -template +template void evaluate_queries( - const std::string& index_filename, - const std::string& wand_data_filename, + Index&& index, + Wdata&& wdata, const std::vector& queries, const std::optional& thresholds_filename, std::string const& type, @@ -44,9 +44,6 @@ void evaluate_queries( std::string const& run_id, std::string const& iteration) { - IndexType index(MemorySource::mapped_file(index_filename)); - WandType const wdata(MemorySource::mapped_file(wand_data_filename)); - auto scorer = scorer::from_params(scorer_params, wdata); std::function>(Query)> query_fun; @@ -199,41 +196,35 @@ int main(int argc, const char** argv) auto iteration = "Q0"; - auto params = std::make_tuple( - app.index_filename(), - app.wand_data_path(), - app.queries(), - app.thresholds_file(), - app.index_encoding(), - app.algorithm(), - app.k(), - documents_file, - app.scorer_params(), - run_id, - iteration); - - /**/ - if (false) { // NOLINT -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ - { \ - if (app.is_wand_compressed()) { \ - if (quantized) { \ - std::apply( \ - evaluate_queries, \ - params); \ - } else { \ - std::apply(evaluate_queries, params); \ - } \ - } else { \ - std::apply(evaluate_queries, params); \ - } \ - /**/ - - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - } else { - spdlog::error("Unknown type {}", app.index_encoding()); + try { + IndexType::resolve(app.index_encoding()).load_and_execute(app.index_filename(), [&](auto&& index) { + auto evaluate = [&](auto wdata) { + evaluate_queries( + index, + wdata, + app.queries(), + app.thresholds_file(), + app.index_encoding(), + app.algorithm(), + app.k(), + documents_file, + app.scorer_params(), + run_id, + iteration); + }; + auto wdata_source = MemorySource::mapped_file(app.wand_data_path()); + if (app.is_wand_compressed()) { + if (quantized) { + evaluate(wand_uniform_index_quantized(std::move(wdata_source))); + } else { + evaluate(wand_uniform_index(std::move(wdata_source))); + } + } else { + evaluate(wand_raw_index(std::move(wdata_source))); + } + }); + } catch (std::exception const& err) { + spdlog::error("{}", err.what()); + return 1; } } diff --git a/tools/kth_threshold.cpp b/tools/kth_threshold.cpp index 9995c51de..9ebeb0918 100644 --- a/tools/kth_threshold.cpp +++ b/tools/kth_threshold.cpp @@ -49,10 +49,10 @@ std::set parse_tuple(std::string const& line, size_t k) return term_ids_int; } -template -void kt_thresholds( - const std::string& index_filename, - const std::string& wand_data_filename, +template +void kth_thresholds( + Index&& index, + Wdata&& wdata, const std::vector& queries, std::string const& type, ScorerParams const& scorer_params, @@ -63,23 +63,8 @@ void kt_thresholds( bool all_pairs, bool all_triples) { - IndexType index; - mio::mmap_source m(index_filename.c_str()); - mapper::map(index, m); - - WandType wdata; - auto scorer = scorer::from_params(scorer_params, wdata); - mio::mmap_source md; - std::error_code error; - md.map(wand_data_filename, error); - if (error) { - spdlog::error("error mapping file: {}, exiting...", error.message()); - std::abort(); - } - mapper::map(wdata, md, mapper::map_flags::warmup); - using Pair = std::set; std::unordered_set> pairs_set; @@ -201,27 +186,35 @@ int main(int argc, const char** argv) all_pairs, all_triples); - /**/ - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ - { \ - if (app.is_wand_compressed()) { \ - if (quantized) { \ - std::apply( \ - kt_thresholds, params); \ - } else { \ - std::apply(kt_thresholds, params); \ - } \ - } else { \ - std::apply(kt_thresholds, params); \ - } - /**/ - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - - } else { - spdlog::error("Unknown type {}", app.index_encoding()); + try { + IndexType::resolve(app.index_encoding()).load_and_execute(app.index_filename(), [&](auto&& index) { + auto thresholds = [&](auto wdata) { + kth_thresholds( + index, + wdata, + app.queries(), + app.index_encoding(), + app.scorer_params(), + app.k(), + quantized, + pairs_filename, + triples_filename, + all_pairs, + all_triples); + }; + auto wdata_source = MemorySource::mapped_file(app.wand_data_path()); + if (app.is_wand_compressed()) { + if (quantized) { + thresholds(wand_uniform_index_quantized(std::move(wdata_source))); + } else { + thresholds(wand_uniform_index(std::move(wdata_source))); + } + } else { + thresholds(wand_raw_index(std::move(wdata_source))); + } + }); + } catch (std::exception const& err) { + spdlog::error("{}", err.what()); + return 1; } } diff --git a/tools/profile_queries.cpp b/tools/profile_queries.cpp index f662b7738..85cc773e2 100644 --- a/tools/profile_queries.cpp +++ b/tools/profile_queries.cpp @@ -50,20 +50,9 @@ void op_profile(QueryOperator const& query_op, std::vector const& queries } } -template -struct add_profiling { - using type = IndexType; -}; - -template -struct add_profiling> { - using type = block_freq_index; -}; - -template +template void profile( - const std::string index_filename, - + Index&& index, const std::optional& wand_data_filename, std::vector const& queries, std::string const& type, @@ -71,12 +60,7 @@ void profile( { using namespace pisa; - typename add_profiling::type index; using WandType = wand_data; - spdlog::info("Loading index from {}", index_filename); - mio::mmap_source m(index_filename); - mapper::map(index, m); - WandType const wdata = [&] { if (wand_data_filename) { return WandType(MemorySource::mapped_file(*wand_data_filename)); @@ -97,19 +81,13 @@ void profile( if (t == "and") { query_fun = [&](Query query) { and_query and_q; - return and_q( - make_cursors::type>(index, query), - index.num_docs()) - .size(); + return and_q(make_cursors(index, query), index.num_docs()).size(); }; } else if (t == "ranked_and" && wand_data_filename) { query_fun = [&](Query query) { topk_queue topk(10); ranked_and_query ranked_and_q(topk); - ranked_and_q( - make_scored_cursors::type>( - index, *scorer, query), - index.num_docs()); + ranked_and_q(make_scored_cursors(index, *scorer, query), index.num_docs()); topk.finalize(); return topk.topk().size(); }; @@ -117,10 +95,7 @@ void profile( query_fun = [&](Query query) { topk_queue topk(10); wand_query wand_q(topk); - wand_q( - make_max_scored_cursors::type, WandType>( - index, wdata, *scorer, query), - index.num_docs()); + wand_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); topk.finalize(); return topk.topk().size(); }; @@ -128,10 +103,7 @@ void profile( query_fun = [&](Query query) { topk_queue topk(10); maxscore_query maxscore_q(topk); - maxscore_q( - make_max_scored_cursors::type, WandType>( - index, wdata, *scorer, query), - index.num_docs()); + maxscore_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); topk.finalize(); return topk.topk().size(); }; @@ -176,18 +148,13 @@ int main(int argc, const char** argv) } } - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (type == BOOST_PP_STRINGIZE(T)) \ - { \ - profile( \ - index_filename, wand_data_filename, queries, type, query_type); \ - /**/ - - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - } else { - spdlog::error("Unknown type {}", type); + try { + spdlog::info("Loading index from {}", index_filename); + IndexType::resolve_profiling(type).load_and_execute(index_filename, [&](auto&& index) { + profile(index, wand_data_filename, queries, type, query_type); + }); + } catch (std::exception const& err) { + spdlog::error("{}", err.what()); + return 1; } } diff --git a/tools/queries.cpp b/tools/queries.cpp index f7344279d..b33752ae3 100644 --- a/tools/queries.cpp +++ b/tools/queries.cpp @@ -114,10 +114,10 @@ void op_perftest( } } -template +template void perftest( - const std::string& index_filename, - const std::optional& wand_data_filename, + Index&& index, + std::optional const& wdata, const std::vector& queries, const std::optional& thresholds_filename, std::string const& type, @@ -127,9 +127,6 @@ void perftest( bool extract, bool safe) { - spdlog::info("Loading index from {}", index_filename); - IndexType index(MemorySource::mapped_file(index_filename)); - spdlog::info("Warming up posting lists"); std::unordered_set warmed_up; for (auto const& q: queries) { @@ -141,13 +138,6 @@ void perftest( } } - WandType const wdata = [&] { - if (wand_data_filename) { - return WandType(MemorySource::mapped_file(*wand_data_filename)); - } - return WandType{}; - }(); - std::vector thresholds(queries.size(), 0.0); if (thresholds_filename) { std::string t; @@ -162,7 +152,12 @@ void perftest( } } - auto scorer = scorer::from_params(scorer_params, wdata); + auto scorer = [&]() -> decltype(scorer::from_params(scorer_params, *wdata)) { + if (wdata) { + return scorer::from_params(scorer_params, *wdata); + } + return nullptr; + }(); spdlog::info("Performing {} queries", type); spdlog::info("K: {}", k); @@ -172,7 +167,7 @@ void perftest( for (auto&& t: query_types) { spdlog::info("Query type: {}", t); - std::function query_fun; + std::function query_fun; // = [](Query, Threshold) {}; if (t == "and") { query_fun = [&](Query query, Threshold) { and_query and_q; @@ -188,36 +183,36 @@ void perftest( or_query or_q; return or_q(make_cursors(index, query), index.num_docs()); }; - } else if (t == "wand" && wand_data_filename) { + } else if (t == "wand" && wdata) { query_fun = [&](Query query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); wand_query wand_q(topk); - wand_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + wand_q(make_max_scored_cursors(index, *wdata, *scorer, query), index.num_docs()); topk.finalize(); return topk.topk().size(); }; - } else if (t == "block_max_wand" && wand_data_filename) { + } else if (t == "block_max_wand" && wdata) { query_fun = [&](Query query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); block_max_wand_query block_max_wand_q(topk); block_max_wand_q( - make_block_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + make_block_max_scored_cursors(index, *wdata, *scorer, query), index.num_docs()); topk.finalize(); return topk.topk().size(); }; - } else if (t == "block_max_maxscore" && wand_data_filename) { + } else if (t == "block_max_maxscore" && wdata) { query_fun = [&](Query query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); block_max_maxscore_query block_max_maxscore_q(topk); block_max_maxscore_q( - make_block_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + make_block_max_scored_cursors(index, *wdata, *scorer, query), index.num_docs()); topk.finalize(); return topk.topk().size(); }; - } else if (t == "ranked_and" && wand_data_filename) { + } else if (t == "ranked_and" && wdata) { query_fun = [&](Query query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); @@ -226,17 +221,17 @@ void perftest( topk.finalize(); return topk.topk().size(); }; - } else if (t == "block_max_ranked_and" && wand_data_filename) { + } else if (t == "block_max_ranked_and" && wdata) { query_fun = [&](Query query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); block_max_ranked_and_query block_max_ranked_and_q(topk); block_max_ranked_and_q( - make_block_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + make_block_max_scored_cursors(index, *wdata, *scorer, query), index.num_docs()); topk.finalize(); return topk.topk().size(); }; - } else if (t == "ranked_or" && wand_data_filename) { + } else if (t == "ranked_or" && wdata) { query_fun = [&](Query query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); @@ -245,16 +240,16 @@ void perftest( topk.finalize(); return topk.topk().size(); }; - } else if (t == "maxscore" && wand_data_filename) { + } else if (t == "maxscore" && wdata) { query_fun = [&](Query query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); maxscore_query maxscore_q(topk); - maxscore_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + maxscore_q(make_max_scored_cursors(index, *wdata, *scorer, query), index.num_docs()); topk.finalize(); return topk.topk().size(); }; - } else if (t == "ranked_or_taat" && wand_data_filename) { + } else if (t == "ranked_or_taat" && wdata) { Simple_Accumulator accumulator(index.num_docs()); topk_queue topk(k); ranked_or_taat_query ranked_or_taat_q(topk); @@ -265,7 +260,7 @@ void perftest( topk.finalize(); return topk.topk().size(); }; - } else if (t == "ranked_or_taat_lazy" && wand_data_filename) { + } else if (t == "ranked_or_taat_lazy" && wdata) { Lazy_Accumulator<4> accumulator(index.num_docs()); topk_queue topk(k); ranked_or_taat_query ranked_or_taat_q(topk); @@ -292,6 +287,15 @@ using wand_raw_index = wand_data; using wand_uniform_index = wand_data>; using wand_uniform_index_quantized = wand_data>; +template +auto load_wdata(std::optional const& wand_data_filename) -> std::optional +{ + if (wand_data_filename) { + return Wdata(MemorySource::mapped_file(*wand_data_filename)); + } + return std::nullopt; +} + int main(int argc, const char** argv) { bool extract = false; @@ -322,37 +326,34 @@ int main(int argc, const char** argv) std::cout << "qid\tusec\n"; } - auto params = std::make_tuple( - app.index_filename(), - app.wand_data_path(), - app.queries(), - app.thresholds_file(), - app.index_encoding(), - app.algorithm(), - app.k(), - app.scorer_params(), - extract, - safe); - /**/ - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ - { \ - if (app.is_wand_compressed()) { \ - if (quantized) { \ - std::apply(perftest, params); \ - } else { \ - std::apply(perftest, params); \ - } \ - } else { \ - std::apply(perftest, params); \ - } - /**/ - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - - } else { - spdlog::error("Unknown type {}", app.index_encoding()); + try { + spdlog::info("Loading index from {}", app.index_filename()); + IndexType::resolve(app.index_encoding()).load_and_execute(app.index_filename(), [&](auto&& index) { + auto perf = [&](auto wdata) { + perftest( + index, + wdata, + app.queries(), + app.thresholds_file(), + app.index_encoding(), + app.algorithm(), + app.k(), + app.scorer_params(), + extract, + safe); + }; + if (app.is_wand_compressed()) { + if (quantized) { + perf(load_wdata(app.wand_data_path())); + } else { + perf(load_wdata(app.wand_data_path())); + } + } else { + perf(load_wdata(app.wand_data_path())); + } + }); + } catch (std::exception const& err) { + spdlog::error("{}", err.what()); + return 1; } } diff --git a/tools/selective_queries.cpp b/tools/selective_queries.cpp index 112f9e5a2..59afc8389 100644 --- a/tools/selective_queries.cpp +++ b/tools/selective_queries.cpp @@ -14,15 +14,9 @@ using namespace pisa; -template -void selective_queries( - const std::string& index_filename, std::string const& encoding, std::vector const& queries) +template +void selective_queries(Index&& index, std::string const& encoding, std::vector const& queries) { - IndexType index; - spdlog::info("Loading index from {}", index_filename); - mio::mmap_source m(index_filename.c_str()); - mapper::map(index, m, mapper::map_flags::warmup); - spdlog::info("Performing {} queries", encoding); using boost::adaptors::transformed; @@ -46,18 +40,13 @@ int main(int argc, const char** argv) "Filters selective queries for a given index."}; CLI11_PARSE(app, argc, argv); - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ - { \ - selective_queries( \ - app.index_filename(), app.index_encoding(), app.queries()); - /**/ - - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - } else { - spdlog::error("Unknown encoding {}", app.index_encoding()); + try { + IndexType::resolve_profiling(app.index_encoding()) + .load_and_execute(app.index_filename(), [&](auto&& index) { + selective_queries(index, app.index_encoding(), app.queries()); + }); + } catch (std::exception const& err) { + spdlog::error("{}", err.what()); + return 1; } } diff --git a/tools/thresholds.cpp b/tools/thresholds.cpp index 82a5fda31..6f32e5945 100644 --- a/tools/thresholds.cpp +++ b/tools/thresholds.cpp @@ -24,19 +24,16 @@ using namespace pisa; -template +template void thresholds( - const std::string& index_filename, - const std::string& wand_data_filename, + Index&& index, + Wdata&& wdata, const std::vector& queries, std::string const& type, ScorerParams const& scorer_params, uint64_t k, bool quantized) { - IndexType index(MemorySource::mapped_file(index_filename)); - WandType const wdata(MemorySource::mapped_file(wand_data_filename)); - auto scorer = scorer::from_params(scorer_params, wdata); topk_queue topk(k); @@ -83,27 +80,32 @@ int main(int argc, const char** argv) app.k(), quantized); - /**/ - if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ - { \ - if (app.is_wand_compressed()) { \ - if (quantized) { \ - std::apply( \ - thresholds, params); \ - } else { \ - std::apply(thresholds, params); \ - } \ - } else { \ - std::apply(thresholds, params); \ - } - /**/ - BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); -#undef LOOP_BODY - - } else { - spdlog::error("Unknown type {}", app.index_encoding()); + try { + IndexType::resolve(app.index_encoding()).load_and_execute(app.index_filename(), [&](auto&& index) { + auto th = [&](auto wdata) { + thresholds( + index, + wdata, + app.queries(), + app.index_encoding(), + app.scorer_params(), + app.k(), + quantized); + }; + auto wdata_source = MemorySource::mapped_file(app.wand_data_path()); + if (app.is_wand_compressed()) { + if (quantized) { + th(wand_uniform_index_quantized(std::move(wdata_source))); + } else { + th(wand_uniform_index(std::move(wdata_source))); + } + } else { + th(wand_raw_index(std::move(wdata_source))); + } + }); + } catch (std::exception const& err) { + spdlog::error("{}", err.what()); + return 1; } + return 0; }