Skip to content

Commit b7b342c

Browse files
allightcopybara-github
authored andcommitted
Improve interval set reduction for variadic operations.
The interval set reduction for variadic operations is now performed dynamically based on the total number of combinations of intervals across all operands. A new function, `ReduceIntervalFragmentationForOp`, iteratively merges intervals in the input sets, prioritizing merges that result in the smallest loss of precision (i.e., those with the smallest gaps between intervals). This ensures that the number of operations remains below a defined threshold, preventing exponential blow-up, while still maintaining as much precision as possible. Fixes: #4233 PiperOrigin-RevId: 913911845
1 parent 3d0955c commit b7b342c

3 files changed

Lines changed: 114 additions & 21 deletions

File tree

xls/ir/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ cc_library(
248248
"@com_google_absl//absl/container:fixed_array",
249249
"@com_google_absl//absl/log",
250250
"@com_google_absl//absl/log:check",
251+
"@com_google_absl//absl/numeric:int128",
251252
"@com_google_absl//absl/types:span",
252253
],
253254
)

xls/ir/interval_ops.cc

Lines changed: 93 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "absl/container/fixed_array.h"
2929
#include "absl/log/check.h"
3030
#include "absl/log/log.h"
31+
#include "absl/numeric/int128.h"
3132
#include "absl/types/span.h"
3233
#include "xls/common/iter_util.h"
3334
#include "xls/common/iterator_range.h"
@@ -516,6 +517,96 @@ std::optional<IntervalSet> MaybePerformExactCalculation(
516517
return std::move(results);
517518
}
518519

520+
// How many computations are we willing to do for a variadic operation.
521+
static constexpr int64_t kMaxVariadicOperations = 1000000;
522+
523+
struct SafeMultiplyResult {
524+
bool overflow = false;
525+
uint64_t result = 0;
526+
};
527+
528+
// TODO(allight): Replace this with safe_int_ops once it is available.
529+
SafeMultiplyResult SafeMultiply(uint64_t a, uint64_t b) {
530+
absl::uint128 big_a = a;
531+
absl::uint128 big_b = b;
532+
absl::uint128 product = big_a * big_b;
533+
if (product > std::numeric_limits<uint64_t>::max()) {
534+
return {true, 0};
535+
}
536+
return {false, static_cast<uint64_t>(product)};
537+
}
538+
539+
SafeMultiplyResult SafeMultiply(SafeMultiplyResult a, uint64_t b) {
540+
if (a.overflow) {
541+
return {true, 0};
542+
}
543+
return SafeMultiply(a.result, b);
544+
}
545+
546+
// Heuristically reduce the specificity of the input interval sets to avoid
547+
// exponential blow up. We do this by merging intervals which cause the least
548+
// loss of precision iteratively until we are below the threshold.
549+
std::vector<IntervalSet> ReduceIntervalFragmentationForOp(
550+
absl::Span<const IntervalSet> interval_sets) {
551+
// Since we do variadic ops the number of operations is proportional to the
552+
// product of the number of intervals in each interval set.
553+
std::vector<IntervalSet> results(interval_sets.begin(), interval_sets.end());
554+
555+
int64_t max_bit_width = 0;
556+
int64_t total_intervals = 0;
557+
for (const auto& is : interval_sets) {
558+
max_bit_width = std::max(max_bit_width, is.BitCount());
559+
total_intervals += is.NumberOfIntervals();
560+
}
561+
562+
// NB This will iterate at most the total number of intervals in all the
563+
// interval sets since we combine one at a time. We should always break before
564+
// we hit this limit but this is a good correctness check to stop us from
565+
// going on forever.
566+
for (int64_t i = 0; i < total_intervals; ++i) {
567+
auto [overflow, product] = absl::c_accumulate(
568+
results, SafeMultiplyResult{false, 1},
569+
[&](SafeMultiplyResult acc, const IntervalSet& is) {
570+
return SafeMultiply(acc,
571+
static_cast<uint64_t>(is.NumberOfIntervals()));
572+
});
573+
574+
if (!overflow && product <= kMaxVariadicOperations) {
575+
return results;
576+
}
577+
578+
std::optional<int64_t> best_set_idx;
579+
std::optional<Bits> min_gap_size;
580+
581+
// NB We could do this with min_element or something but this lets us keep
582+
// around the best measure too.
583+
for (int64_t i = 0; i < results.size(); ++i) {
584+
const auto& is = results[i];
585+
if (is.NumberOfIntervals() <= 1) {
586+
continue;
587+
}
588+
for (int64_t j = 0; j < is.NumberOfIntervals() - 1; ++j) {
589+
Bits gap_size = bits_ops::Sub(is.Intervals()[j + 1].LowerBound(),
590+
is.Intervals()[j].UpperBound());
591+
Bits extended_gap_size = bits_ops::ZeroExtend(gap_size, max_bit_width);
592+
593+
if (!min_gap_size.has_value() ||
594+
bits_ops::ULessThan(extended_gap_size, *min_gap_size)) {
595+
min_gap_size = extended_gap_size;
596+
best_set_idx = i;
597+
}
598+
}
599+
}
600+
CHECK(best_set_idx.has_value());
601+
602+
results[*best_set_idx] = MinimizeIntervals(
603+
results[*best_set_idx], results[*best_set_idx].NumberOfIntervals() - 1);
604+
}
605+
LOG(FATAL) << "Failed to reduce interval fragmentation. Expected to break "
606+
"before hitting "
607+
<< total_intervals << " iterations";
608+
}
609+
519610
template <typename Calculate>
520611
requires(
521612
std::is_invocable_r_v<OverflowResult, Calculate, absl::Span<Bits const>>)
@@ -528,30 +619,11 @@ IntervalSet PerformVariadicOp(Calculate calc,
528619
std::optional<IntervalSet> exact_result = MaybePerformExactCalculation(
529620
calc, behaviors, input_operands, result_bit_size);
530621
if (exact_result) {
531-
// VLOG(2) << "Got exact results: " << exact_result->ToString();
532622
return *exact_result;
533623
}
534624

535-
std::vector<IntervalSet> operands;
536-
operands.reserve(input_operands.size());
537-
538-
{
539-
int64_t i = 0;
540-
for (IntervalSet interval_set : input_operands) {
541-
// TODO(taktoa): we could choose the minimized interval sets more
542-
// carefully, since `MinimizeIntervals` is minimizing optimally for each
543-
// interval set without the knowledge that other interval sets exist.
544-
// For example, we could call `ConvexHull` greedily on the sets
545-
// that have the smallest difference between convex hull size and size.
546-
547-
// TODO(allight): We might want to distribute the intervals more evenly
548-
// then just giving the first 12 operands 5 segments and the rest 1.
549-
// Limit exponential growth after 12 parameters. 5^12 = 244 million
550-
interval_set = MinimizeIntervals(interval_set, (i < 12) ? 5 : 1);
551-
operands.push_back(interval_set);
552-
++i;
553-
}
554-
}
625+
std::vector<IntervalSet> operands =
626+
ReduceIntervalFragmentationForOp(input_operands);
555627

556628
if (absl::c_all_of(operands,
557629
[](const IntervalSet& i) { return i.IsPrecise(); })) {

xls/ir/interval_ops_test.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,5 +1506,25 @@ FUZZ_TEST(IntervalOpsTest, OneHotZ3Fuzz)
15061506
.WithDomains(IntervalDomain(8),
15071507
fuzztest::ElementOf({LsbOrMsb::kLsb, LsbOrMsb::kMsb}));
15081508

1509+
TEST(IntervalOpsTest, ReduceIntervalFragmentation) {
1510+
// Create a set with 10 separate intervals.
1511+
IntervalSet lhs = FromValues({0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, 8);
1512+
IntervalSet rhs = lhs;
1513+
1514+
IntervalSet result = Add(lhs, rhs);
1515+
1516+
// The result should not be merged into a single interval because we didn't
1517+
// aggressively reduce the inputs.
1518+
EXPECT_GT(result.NumberOfIntervals(), 1);
1519+
1520+
// Also verify that odd numbers in the range are NOT covered,
1521+
// e.g., 7 is not covered because it's not generated by adding even numbers
1522+
// and wasn't merged.
1523+
// It would be nice to be more specific about what we expect but the merging
1524+
// behavior is complicated and we don't want to force particular choices with
1525+
// this test.
1526+
EXPECT_FALSE(result.Covers(UBits(7, 8)));
1527+
}
1528+
15091529
} // namespace
15101530
} // namespace xls::interval_ops

0 commit comments

Comments
 (0)