2828#include " absl/container/fixed_array.h"
2929#include " absl/log/check.h"
3030#include " absl/log/log.h"
31+ #include " absl/numeric/int128.h"
3132#include " absl/types/span.h"
3233#include " xls/common/iter_util.h"
3334#include " xls/common/iterator_range.h"
@@ -516,6 +517,96 @@ std::optional<IntervalSet> MaybePerformExactCalculation(
516517 return std::move (results);
517518}
518519
// Upper bound on the number of element-wise computations we are willing to
// perform for a single variadic operation. Exceeding it triggers heuristic
// coarsening of the input interval sets (see
// ReduceIntervalFragmentationForOp).
static constexpr int64_t kMaxVariadicOperations = 1000000;
522+
// Result of an overflow-checked uint64 multiplication.
struct SafeMultiplyResult {
  bool overflow = false;
  uint64_t result = 0;
};

// TODO(allight): Replace this with safe_int_ops once it is available.
//
// Multiplies `a * b`, reporting overflow instead of wrapping. The product
// a * b exceeds the uint64 range exactly when b != 0 and a > max / b
// (integer division), so no wider integer type is needed for the check.
SafeMultiplyResult SafeMultiply(uint64_t a, uint64_t b) {
  constexpr uint64_t kUint64Max = std::numeric_limits<uint64_t>::max();
  if (b != 0 && a > kUint64Max / b) {
    return {/*overflow=*/true, /*result=*/0};
  }
  return {/*overflow=*/false, /*result=*/a * b};
}

// Chaining overload: once a previous step has overflowed, the whole product
// is considered overflowed and the multiplication is skipped.
SafeMultiplyResult SafeMultiply(SafeMultiplyResult a, uint64_t b) {
  return a.overflow ? SafeMultiplyResult{true, 0} : SafeMultiply(a.result, b);
}
545+
546+ // Heuristically reduce the specificity of the input interval sets to avoid
547+ // exponential blow up. We do this by merging intervals which cause the least
548+ // loss of precision iteratively until we are below the threshold.
549+ std::vector<IntervalSet> ReduceIntervalFragmentationForOp (
550+ absl::Span<const IntervalSet> interval_sets) {
551+ // Since we do variadic ops the number of operations is proportional to the
552+ // product of the number of intervals in each interval set.
553+ std::vector<IntervalSet> results (interval_sets.begin (), interval_sets.end ());
554+
555+ int64_t max_bit_width = 0 ;
556+ int64_t total_intervals = 0 ;
557+ for (const auto & is : interval_sets) {
558+ max_bit_width = std::max (max_bit_width, is.BitCount ());
559+ total_intervals += is.NumberOfIntervals ();
560+ }
561+
562+ // NB This will iterate at most the total number of intervals in all the
563+ // interval sets since we combine one at a time. We should always break before
564+ // we hit this limit but this is a good correctness check to stop us from
565+ // going on forever.
566+ for (int64_t i = 0 ; i < total_intervals; ++i) {
567+ auto [overflow, product] = absl::c_accumulate (
568+ results, SafeMultiplyResult{false , 1 },
569+ [&](SafeMultiplyResult acc, const IntervalSet& is) {
570+ return SafeMultiply (acc,
571+ static_cast <uint64_t >(is.NumberOfIntervals ()));
572+ });
573+
574+ if (!overflow && product <= kMaxVariadicOperations ) {
575+ return results;
576+ }
577+
578+ std::optional<int64_t > best_set_idx;
579+ std::optional<Bits> min_gap_size;
580+
581+ // NB We could do this with min_element or something but this lets us keep
582+ // around the best measure too.
583+ for (int64_t i = 0 ; i < results.size (); ++i) {
584+ const auto & is = results[i];
585+ if (is.NumberOfIntervals () <= 1 ) {
586+ continue ;
587+ }
588+ for (int64_t j = 0 ; j < is.NumberOfIntervals () - 1 ; ++j) {
589+ Bits gap_size = bits_ops::Sub (is.Intervals ()[j + 1 ].LowerBound (),
590+ is.Intervals ()[j].UpperBound ());
591+ Bits extended_gap_size = bits_ops::ZeroExtend (gap_size, max_bit_width);
592+
593+ if (!min_gap_size.has_value () ||
594+ bits_ops::ULessThan (extended_gap_size, *min_gap_size)) {
595+ min_gap_size = extended_gap_size;
596+ best_set_idx = i;
597+ }
598+ }
599+ }
600+ CHECK (best_set_idx.has_value ());
601+
602+ results[*best_set_idx] = MinimizeIntervals (
603+ results[*best_set_idx], results[*best_set_idx].NumberOfIntervals () - 1 );
604+ }
605+ LOG (FATAL) << " Failed to reduce interval fragmentation. Expected to break "
606+ " before hitting "
607+ << total_intervals << " iterations" ;
608+ }
609+
519610template <typename Calculate>
520611 requires (
521612 std::is_invocable_r_v<OverflowResult, Calculate, absl::Span<Bits const >>)
@@ -528,30 +619,11 @@ IntervalSet PerformVariadicOp(Calculate calc,
528619 std::optional<IntervalSet> exact_result = MaybePerformExactCalculation (
529620 calc, behaviors, input_operands, result_bit_size);
530621 if (exact_result) {
531- // VLOG(2) << "Got exact results: " << exact_result->ToString();
532622 return *exact_result;
533623 }
534624
535- std::vector<IntervalSet> operands;
536- operands.reserve (input_operands.size ());
537-
538- {
539- int64_t i = 0 ;
540- for (IntervalSet interval_set : input_operands) {
541- // TODO(taktoa): we could choose the minimized interval sets more
542- // carefully, since `MinimizeIntervals` is minimizing optimally for each
543- // interval set without the knowledge that other interval sets exist.
544- // For example, we could call `ConvexHull` greedily on the sets
545- // that have the smallest difference between convex hull size and size.
546-
547- // TODO(allight): We might want to distribute the intervals more evenly
548- // then just giving the first 12 operands 5 segments and the rest 1.
549- // Limit exponential growth after 12 parameters. 5^12 = 244 million
550- interval_set = MinimizeIntervals (interval_set, (i < 12 ) ? 5 : 1 );
551- operands.push_back (interval_set);
552- ++i;
553- }
554- }
625+ std::vector<IntervalSet> operands =
626+ ReduceIntervalFragmentationForOp (input_operands);
555627
556628 if (absl::c_all_of (operands,
557629 [](const IntervalSet& i) { return i.IsPrecise (); })) {
0 commit comments