diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index d6b51a1bd578..44d0eafb60ae 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -77,15 +77,35 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { if (info && op->type.is_int_or_uint()) { switch (op->op) { - case VectorReduce::Add: - // Alignment of result is the alignment of the arg. Bounds - // of the result can grow according to the reduction - // factor. - info->bounds = cast(op->type, info->bounds * factor); + case VectorReduce::Add: { + // A horizontal add of `factor` lanes is the sum of `factor` + // (possibly distinct) values each in `info->bounds` with + // alignment `info->alignment`. Treating it as multiplication by + // `factor` would be wrong -- that would claim a tighter modulus + // than we actually have. Instead we add the per-lane alignment to + // itself `factor` times. + ModulusRemainder one_lane = info->alignment; + info->bounds = info->bounds * factor; + for (int i = 1; i < factor; i++) { + info->alignment = info->alignment + one_lane; + } + info->cast_to(op->type); break; - case VectorReduce::SaturatingAdd: - info->bounds = saturating_cast(op->type, info->bounds * factor); + } + case VectorReduce::SaturatingAdd: { + ConstantInterval unsaturated = info->bounds * factor; + if (op->type.can_represent(unsaturated)) { + ModulusRemainder one_lane = info->alignment; + info->bounds = unsaturated; + for (int i = 1; i < factor; i++) { + info->alignment = info->alignment + one_lane; + } + } else { + info->bounds = saturating_cast(op->type, unsaturated); + info->alignment = ModulusRemainder{}; + } break; + } case VectorReduce::Mul: // Don't try to infer anything about bounds. Leave the // alignment unchanged even though we could theoretically diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index ff934e1b82ba..a0ffe8507bed 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -881,6 +881,25 @@ void check_vectors() { Expr u8_x = Variable::make(UInt(8), "u8_x"); check(VectorReduce::make(VectorReduce::Add, broadcast(u8_x, 9), 3), broadcast(u8_x * cast(UInt(8), 3), 3)); } + + { + // Regression test for https://github.com/halide/Halide/issues/9100. + // Horizontal add of `factor` lanes, each `r (mod m)`, has alignment + // `(factor * r) (mod m)` -- the modulus does NOT scale up, because + // the lanes are summed, not multiplied. Previously the simplifier + // failed to update alignment at all across horizontal add, so a + // cast of the result could be folded to the wrong constant. + // A select of broadcasts (which does not rewrite further) is the + // cheapest way to exercise the VectorReduce::Add info-update path. + Expr cond = Variable::make(Bool(), "cond"); + Expr lhs = cast(UInt(16), 12203); // odd + Expr rhs = cast(UInt(16), 10637); // odd + Expr inner = Select::make(Broadcast::make(cond, 2), + Broadcast::make(lhs, 2), + Broadcast::make(rhs, 2)); + check(cast(UInt(1), VectorReduce::make(VectorReduce::Add, inner, 1)), + cast(UInt(1), 0)); + } } void check_bounds() {