Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions src/Simplify_Exprs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,35 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {

if (info && op->type.is_int_or_uint()) {
switch (op->op) {
case VectorReduce::Add:
// Alignment of result is the alignment of the arg. Bounds
// of the result can grow according to the reduction
// factor.
info->bounds = cast(op->type, info->bounds * factor);
case VectorReduce::Add: {
// A horizontal add of `factor` lanes is the sum of `factor`
// (possibly distinct) values each in `info->bounds` with
// alignment `info->alignment`. Treating it as multiplication by
// `factor` would be wrong -- that would claim a tighter modulus
// than we actually have. Instead we add the per-lane alignment to
// itself `factor` times.
ModulusRemainder one_lane = info->alignment;
info->bounds = info->bounds * factor;
for (int i = 1; i < factor; i++) {
info->alignment = info->alignment + one_lane;
}
info->cast_to(op->type);
break;
case VectorReduce::SaturatingAdd:
info->bounds = saturating_cast(op->type, info->bounds * factor);
}
case VectorReduce::SaturatingAdd: {
ConstantInterval unsaturated = info->bounds * factor;
if (op->type.can_represent(unsaturated)) {
ModulusRemainder one_lane = info->alignment;
info->bounds = unsaturated;
for (int i = 1; i < factor; i++) {
info->alignment = info->alignment + one_lane;
}
} else {
info->bounds = saturating_cast(op->type, unsaturated);
info->alignment = ModulusRemainder{};
}
break;
}
case VectorReduce::Mul:
// Don't try to infer anything about bounds. Leave the
// alignment unchanged even though we could theoretically
Expand Down
19 changes: 19 additions & 0 deletions test/correctness/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,25 @@ void check_vectors() {
Expr u8_x = Variable::make(UInt(8), "u8_x");
check(VectorReduce::make(VectorReduce::Add, broadcast(u8_x, 9), 3), broadcast(u8_x * cast(UInt(8), 3), 3));
}

{
// Regression test for https://github.com/halide/Halide/issues/9100.
// Horizontal add of `factor` lanes, each `r (mod m)`, has alignment
// `(factor * r) (mod m)` -- the modulus does NOT scale up, because
// the lanes are summed, not multiplied. Previously the simplifier
// failed to update alignment at all across horizontal add, so a
// cast<uint1> of the result could be folded to the wrong constant.
// A select of broadcasts (which does not rewrite further) is the
// cheapest way to exercise the VectorReduce::Add info-update path.
Expr cond = Variable::make(Bool(), "cond");
Expr lhs = cast(UInt(16), 12203); // odd
Expr rhs = cast(UInt(16), 10637); // odd
Expr inner = Select::make(Broadcast::make(cond, 2),
Broadcast::make(lhs, 2),
Broadcast::make(rhs, 2));
check(cast(UInt(1), VectorReduce::make(VectorReduce::Add, inner, 1)),
cast(UInt(1), 0));
}
}

void check_bounds() {
Expand Down
Loading