diff --git a/include/valik/search/local_prefilter.hpp b/include/valik/search/local_prefilter.hpp index b4599f58..a09eb1c5 100644 --- a/include/valik/search/local_prefilter.hpp +++ b/include/valik/search/local_prefilter.hpp @@ -13,6 +13,40 @@ namespace valik { +/** + * @brief Function that samples patterns on a query. + * + * @param read_len Length of query. + * @param pattern_size Length of pattern. + * @param query_every Sampled begin positions are query_every * pattern_size apart. + * @param callback Functor that corrects the threshold based on matching k-mer counts. + * @return Threshold correction coefficient: 1 plus the mean of the positive corrections returned by callback (1 if none are positive). + */ +template +constexpr double sample_begin_positions(size_t const read_len, uint64_t const pattern_size, uint8_t const query_every, functor_t && callback) +{ + assert(read_len >= pattern_size); + + size_t first_pos{pattern_size}; + if (read_len < query_every + pattern_size) + first_pos = 0; // start from beginning when short query + + size_t corrected_pattern_count{0u}; + double total_correction{0}; + for (size_t pos = first_pos; pos <= read_len - pattern_size; pos = pos + query_every * pattern_size) + { + auto correction = callback(pos); + if (correction > 0) + { + corrected_pattern_count++; + total_correction += correction; + } + } + + return 1 + total_correction / (double) std::max(corrected_pattern_count, (size_t) 1); +} + + /** * @brief Function that finds the begin positions of all pattern of a query. 
* @@ -33,10 +67,10 @@ constexpr void pattern_begin_positions(size_t const read_len, uint64_t const pat assert(read_len >= pattern_size); size_t last_begin{0u}; - for (size_t i = 0; i <= read_len - pattern_size; i = i + query_every) + for (size_t pos = 0; pos <= read_len - pattern_size; pos = pos + query_every) { - callback(i); - last_begin = i; + callback(pos); + last_begin = pos; } if (last_begin < read_len - pattern_size) @@ -54,6 +88,11 @@ struct pattern_bounds size_t begin_position; size_t end_position; size_t threshold; + + size_t minimiser_count() const + { + return end_position - begin_position; + } }; /** @@ -89,35 +128,96 @@ pattern_bounds make_pattern_bounds(size_t const & begin, assert(end_it != window_span_begin.begin()); pattern.end_position = end_it - window_span_begin.begin(); - size_t const minimiser_count = pattern.end_position - pattern.begin_position; - - pattern.threshold = thresholder.get(minimiser_count); + pattern.threshold = thresholder.get(pattern.minimiser_count()); return pattern; } +/** + * @brief Function that for a single pattern counts matching k-mers and corrects the threshold to avoid too many spuriously matching bins. + * + * @param pattern Slice of a query record that is being considered. + * @param bin_count Number of bins in the IBF. + * @param counting_table Rows: minimisers of the query. Columns: bins of the IBF. + * @return Relative threshold increase (correction count divided by the pattern threshold) that avoids too many spurious matches. 
+ */ +template +double find_dynamic_threshold_correction(pattern_bounds const & pattern, + size_t const & bin_count, + binning_bitvector_t const & counting_table) +{ + // counting vector for the current pattern + seqan3::counting_vector total_counts(bin_count, 0); + + for (size_t i = pattern.begin_position; i < pattern.end_position; i++) + total_counts += counting_table[i]; + + std::unordered_set pattern_hits; + + bool max_threshold{false}; + uint8_t correction_count{0}; + while (true) + { + for (size_t current_bin = 0; current_bin < total_counts.size(); current_bin++) + { + auto &&count = total_counts[current_bin]; + if (count >= (pattern.threshold + correction_count)) + { + pattern_hits.insert(current_bin); + } + } + if ((pattern.threshold + correction_count) >= pattern.minimiser_count()) + max_threshold = true; + if (pattern_hits.size() < std::max((size_t) 4, (size_t) std::round(bin_count / 4.0)) || + max_threshold) + break; + else + { + pattern_hits.clear(); + // increase threshold in 10% increments or by at least 1 to find lowest threshold that is not ubiquitous + correction_count += std::max((size_t) 1, (size_t) std::round(pattern.threshold * 0.1 * correction_count)); + } + } + + return (double) correction_count / (double) pattern.threshold; +} + + /** * @brief Function that for a single pattern counts matching k-mers and returns bins that exceed the threshold. * * @param pattern Slice of a query record that is being considered. + * @param correction_coef Threshold correction determined from a sample of patterns. * @param bin_count Number of bins in the IBF. * @param counting_table Rows: minimisers of the query. Columns: bins of the IBF. * @param sequence_hits Bins that likely contain a match for the pattern (IN-OUT parameter). 
*/ template void find_pattern_bins(pattern_bounds const & pattern, - size_t const & bin_count, - binning_bitvector_t const & counting_table, - std::unordered_set & sequence_hits) + double const & correction_coef, + size_t const & bin_count, + binning_bitvector_t const & counting_table, + std::unordered_set & sequence_hits) { // counting vector for the current pattern seqan3::counting_vector total_counts(bin_count, 0); for (size_t i = pattern.begin_position; i < pattern.end_position; i++) total_counts += counting_table[i]; + for (size_t current_bin = 0; current_bin < total_counts.size(); current_bin++) { auto &&count = total_counts[current_bin]; - if (count >= pattern.threshold) + /* + if (current_bin == 0) + { + if (std::round(pattern.threshold * correction_coef) > pattern.threshold) + { + seqan3::debug_stream << "Threshold was " << pattern.threshold << '\n'; + seqan3::debug_stream << "New threshold " << std::to_string((size_t) std::round(pattern.threshold * correction_coef)) << '\n'; + } + } + */ + if (count >= (pattern.threshold * correction_coef)) { // the result is a union of results from all patterns of a read sequence_hits.insert(current_bin); @@ -199,10 +299,22 @@ void local_prefilter( minimiser.clear(); std::unordered_set sequence_hits{}; + double threshold_correction{1}; + if (!arguments.static_threshold) + { + threshold_correction = sample_begin_positions(seq.size(), arguments.pattern_size, arguments.query_every, [&](size_t const begin) -> double + { + pattern_bounds const pattern = make_pattern_bounds(begin, arguments, window_span_begin, thresholder); + return find_dynamic_threshold_correction(pattern, bin_count, counting_table); + }); + } + + if (threshold_correction > 1.0000001) + seqan3::debug_stream << "Correct threshold by " << threshold_correction << '\n'; pattern_begin_positions(seq.size(), arguments.pattern_size, arguments.query_every, [&](size_t const begin) { pattern_bounds const pattern = make_pattern_bounds(begin, arguments, 
window_span_begin, thresholder); - find_pattern_bins(pattern, bin_count, counting_table, sequence_hits); + find_pattern_bins(pattern, threshold_correction, bin_count, counting_table, sequence_hits); }); result_cb(record, sequence_hits); diff --git a/include/valik/search/search_local.hpp b/include/valik/search/search_local.hpp index 60959b9c..9d27a5ce 100644 --- a/include/valik/search/search_local.hpp +++ b/include/valik/search/search_local.hpp @@ -132,7 +132,10 @@ bool search_local(search_arguments & arguments, search_time_statistics & time_st std::cout.precision(3); std::cout << "\n-----------Search parameters-----------\n"; - std::cout << "kmer size " << std::to_string(arguments.shape_size) << '\n'; + if (arguments.shape_size == arguments.shape_weight) + std::cout << "kmer size " << std::to_string(arguments.shape_size) << '\n'; + else + std::cout << "kmer shape " << arguments.shape.to_string() << '\n'; std::cout << "window size " << std::to_string(arguments.window_size) << '\n'; switch (arguments.search_type) { diff --git a/include/valik/shared.hpp b/include/valik/shared.hpp index da173659..e8323548 100644 --- a/include/valik/shared.hpp +++ b/include/valik/shared.hpp @@ -188,6 +188,7 @@ struct search_arguments final : public minimiser_threshold_arguments, search_pro bool keep_best_repeats{false}; double best_bin_entropy_cutoff{0.25}; bool keep_all_repeats{false}; + bool static_threshold{false}; bool stellar_only{false}; size_t cart_max_capacity{1000}; diff --git a/src/argument_parsing/search.cpp b/src/argument_parsing/search.cpp index d3418bb5..209cd3ba 100644 --- a/src/argument_parsing/search.cpp +++ b/src/argument_parsing/search.cpp @@ -106,6 +106,11 @@ void init_search_parser(sharg::parser & parser, search_arguments & arguments) .long_id = "keep-all-repeats", .description = "Do not filter out query matches from repeat regions. 
This may significantly increase the runtime.", .advanced = true}); + parser.add_flag(arguments.static_threshold, + sharg::config{.short_id = '\0', + .long_id = "static-threshold", + .description = "Do not correct threshold to avoid many spuriously matching bins.", + .advanced = true}); parser.add_option(arguments.seg_count_in, sharg::config{.short_id = 'n', .long_id = "seg-count", @@ -346,7 +351,7 @@ void run_search(sharg::parser & parser) { arguments.search_type = search_kind::LEMMA; if (arguments.threshold < lemma_thresh) - std::cerr << "[Warning] chosen threshold is less than the k-mer lemma threshold. Ignore this warning if this was deliberate."; + std::cerr << "[Warning] The chosen threshold is less than the k-mer lemma threshold. Ignore this warning if this was deliberate."; } } if (arguments.stellar_only) diff --git a/test/cli/valik_test.cpp b/test/cli/valik_test.cpp index 8fb9ce24..31e2cb0b 100644 --- a/test/cli/valik_test.cpp +++ b/test/cli/valik_test.cpp @@ -350,7 +350,7 @@ TEST_P(valik_search_clusters, search) "--error-rate ", std::to_string(error_rate), "--index ", ibf_path(number_of_bins, window_size), "--query ", data("query.fq"), - "--threads 1", "--very-verbose", + "--threads 1", "--very-verbose", "--static-threshold", "--cart-max-capacity 3", "--max-queued-carts 10", "--without-parameter-tuning"); @@ -399,7 +399,7 @@ TEST_P(valik_search_segments, search) "--error-rate ", std::to_string(error_rate), "--index ", ibf_path(segment_overlap, number_of_bins, window_size), "--query ", data("single_query.fasta"), - "--threads 1", "--very-verbose", + "--threads 1", "--very-verbose", "--static-threshold", "--ref-meta", segment_metadata_path(segment_overlap, number_of_bins), "--cart-max-capacity 3", "--max-queued-carts 10",