Skip to content

Commit d5efca0

Browse files
abadamsclaude
andauthored
Fix bounds inference for implicit pure def with RVar args (#9102) (#9103)
* Fix bounds inference for implicit pure def with RVar args (#9102) When a Func's first definition is an update whose LHS uses an RVar directly (e.g. `h(r.x) += ...`), define_base_case synthesized an implicit pure definition but reused the RVar's name for the pure dimension. The resulting name collision caused bounds inference to resolve the update's RVar loop bounds to the pure dimension's output-buffer bounds instead of the RDom's, which broke scheduling directives like vectorize/unroll on the RVar. Treat Variables with a defined reduction_domain the same way we treat Variables with a defined param: generate a fresh pure-arg Var instead of reusing the name. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Use round RDom extent in implicit_pure_def_with_rvar_args test LLVM's ARM64 backend fails to widen the 15-wide int32 vector store produced by the original reproducer's vectorize(r.x) schedule. The bug under test is about bounds inference, not vector widths, so round the RDom extent up to 16 so vectorize lowers to clean NEON stores on every supported target. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 7c05992 commit d5efca0

3 files changed

Lines changed: 99 additions & 1 deletion

File tree

src/Func.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3190,7 +3190,7 @@ Func define_base_case(const Internal::Function &func, const vector<Expr> &a, con
31903190
// Reuse names of existing pure args
31913191
for (size_t i = 0; i < a.size(); i++) {
31923192
if (const Variable *v = a[i].as<Variable>()) {
3193-
if (!v->param.defined()) {
3193+
if (!v->param.defined() && !v->reduction_domain.defined()) {
31943194
pure_args[i] = Var(v->name);
31953195
}
31963196
} else {

test/correctness/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ tests(GROUPS correctness
180180
image_of_lists.cpp
181181
implicit_args.cpp
182182
implicit_args_tests.cpp
183+
implicit_pure_def_with_rvar_args.cpp
183184
in_place.cpp
184185
indexing_access_undef.cpp
185186
infer_arguments.cpp
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#include "Halide.h"
2+
#include <cstdio>
3+
4+
using namespace Halide;
5+
6+
// Regression test for https://github.com/halide/Halide/issues/9102
7+
//
8+
// When a Func's first definition is an update that uses RVars directly
9+
// as LHS args (e.g. h(r.x) += ...), Halide auto-generates an implicit
10+
// pure definition. The pure dimension must not share a name with the
11+
// RVar, or bounds inference resolves the update's RVar loop bounds to
12+
// the pure dimension's (buffer-driven) bounds instead of the RDom's.
13+
14+
int main(int argc, char **argv) {
15+
Var x;
16+
17+
// Case 1: the original reproducer. vectorize(r.x) requires the RVar
18+
// loop to have a constant extent; before the fix, bounds inference
19+
// produced a symbolic extent from the output buffer and this
20+
// schedule failed to compile.
21+
{
22+
RDom r(0, 16, 0, 8);
23+
Func f{"f"}, g{"g"}, h{"h"};
24+
f(x) = x + 1;
25+
g(x) = 2 * x + 3;
26+
27+
h(r.x) += f(r.x + r.y) * g(r.y);
28+
29+
f.compute_root();
30+
g.compute_root();
31+
h.update().atomic().vectorize(r.x).unroll(r.y);
32+
33+
Buffer<int> out = h.realize({16});
34+
for (int i = 0; i < 16; i++) {
35+
int expected = 0;
36+
for (int j = 0; j < 8; j++) {
37+
expected += (i + j + 1) * (2 * j + 3);
38+
}
39+
if (out(i) != expected) {
40+
printf("Case 1: out(%d) = %d, expected %d\n", i, out(i), expected);
41+
return 1;
42+
}
43+
}
44+
}
45+
46+
// Case 2: same computation, but with an explicit pure definition.
47+
// This was the user's workaround; it must still give the same answer.
48+
{
49+
RDom r(0, 16, 0, 8);
50+
Func f{"f2"}, g{"g2"}, h{"h2"};
51+
f(x) = x + 1;
52+
g(x) = 2 * x + 3;
53+
54+
h(x) = 0;
55+
h(r.x) += f(r.x + r.y) * g(r.y);
56+
57+
f.compute_root();
58+
g.compute_root();
59+
h.update().atomic().vectorize(r.x).unroll(r.y);
60+
61+
Buffer<int> out = h.realize({16});
62+
for (int i = 0; i < 16; i++) {
63+
int expected = 0;
64+
for (int j = 0; j < 8; j++) {
65+
expected += (i + j + 1) * (2 * j + 3);
66+
}
67+
if (out(i) != expected) {
68+
printf("Case 2: out(%d) = %d, expected %d\n", i, out(i), expected);
69+
return 1;
70+
}
71+
}
72+
}
73+
74+
// Case 3: RDom bounds narrower than the realized output. Exercises
75+
// the underlying bounds bug directly (no vectorize needed): without
76+
// a correct loop bound from the RDom, the update would write to
77+
// indices outside the RDom, producing wrong values at the ends.
78+
{
79+
RDom r(2, 5);
80+
Func h{"h3"};
81+
h(r) += cast<int>(r) * 10;
82+
83+
h.update().vectorize(r, 4, TailStrategy::GuardWithIf);
84+
85+
Buffer<int> out = h.realize({10});
86+
for (int i = 0; i < 10; i++) {
87+
int expected = (i >= 2 && i < 7) ? i * 10 : 0;
88+
if (out(i) != expected) {
89+
printf("Case 3: out(%d) = %d, expected %d\n", i, out(i), expected);
90+
return 1;
91+
}
92+
}
93+
}
94+
95+
printf("Success!\n");
96+
return 0;
97+
}

0 commit comments

Comments
 (0)