Skip to content

Commit cb661ed

Browse files
abadamsclaude
andauthored
Fix SkipStages narrowing loaded by Select condition (#9152)
SkipStages::visit(Select) was combining each branch's `loaded` predicate with the select condition (via make_select / make_and). That's wrong: Halide's Select evaluates both branches and only picks one of the results, so any load inside either branch fires unconditionally. The `loaded` predicate must be the OR of both branches, ungated by the condition. The bug caused allocation bounds inference to size affected buffers down to zero whenever the runtime condition was false, while the generated code still emitted a vectorized load from them -- a heap OOB read that showed up intermittently as a valgrind use-after-free on the truncated_pyramid test. Keep `used` gated by the condition as before (with the same select/and collapse that fixed the exponential blow-up in #9147). Add a regression test in skip_stages.cpp that records the minimum producer allocation size through a custom malloc handler and verifies it's non-zero when the producer's only consumer is inside a runtime-false select branch. Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 642087f commit cb661ed

2 files changed

Lines changed: 62 additions & 9 deletions

File tree

src/SkipStages.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -488,11 +488,6 @@ class SkipStages : public IRMutator {
488488
mutate(op->false_value);
489489
// func_info now holds the false-branch info.
490490

491-
// Ids touched on both branches: combine with select(cond, t, f),
492-
// so make_select can collapse `select(c, X, X) -> X` when the
493-
// two branches contributed the same Expr.
494-
// Ids touched on only one branch: AND the predicate with the
495-
// appropriate side of the condition.
496491
auto merge_into_old = [&](size_t id, const Expr &u, const Expr &l) {
497492
auto [q, inserted] = old.try_emplace(id, FuncInfo{u, l});
498493
if (!inserted) {
@@ -505,13 +500,16 @@ class SkipStages : public IRMutator {
505500
size_t id = p.first;
506501
auto it_t = true_info.find(id);
507502
Expr u, l;
503+
// Select evaluates both sides but only uses one, so `loaded` is
504+
// the unconditional OR of the two branches while `used` is gated
505+
// by the condition.
508506
if (it_t != true_info.end()) {
509507
u = make_select(op->condition, it_t->second.used, p.second.used);
510-
l = make_select(op->condition, it_t->second.loaded, p.second.loaded);
508+
l = make_or(it_t->second.loaded, p.second.loaded);
511509
true_info.erase(it_t);
512510
} else {
513511
u = make_and(p.second.used, !op->condition);
514-
l = make_and(p.second.loaded, !op->condition);
512+
l = p.second.loaded;
515513
}
516514
merge_into_old(id, u, l);
517515
}
@@ -521,7 +519,7 @@ class SkipStages : public IRMutator {
521519
for (auto &p : true_info) {
522520
size_t id = p.first;
523521
Expr u = make_and(p.second.used, op->condition);
524-
Expr l = make_and(p.second.loaded, op->condition);
522+
Expr l = p.second.loaded;
525523
merge_into_old(id, u, l);
526524
}
527525
func_info.clear();

test/correctness/skip_stages.cpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,22 @@ extern "C" HALIDE_EXPORT_SYMBOL int call_counter(int x, int idx) {
1010
}
1111
HalideExtern_2(int, call_counter, int, int);
1212

13+
size_t min_alloc_size = (size_t)-1;
14+
void *recording_malloc(JITUserContext *, size_t x) {
15+
if (x < min_alloc_size) {
16+
min_alloc_size = x;
17+
}
18+
// Over-allocate so we can return a 32-byte-aligned pointer; stash the
19+
// original so we can free it later.
20+
void *orig = malloc(x + 32);
21+
void *ptr = (void *)((((size_t)orig + 32) >> 5) << 5);
22+
((void **)ptr)[-1] = orig;
23+
return ptr;
24+
}
25+
void recording_free(JITUserContext *, void *p) {
26+
free(((void **)p)[-1]);
27+
}
28+
1329
void reset_counts() {
1430
for (int i = 0; i < 4; i++) {
1531
call_count[i] = 0;
@@ -245,7 +261,7 @@ int main(int argc, char **argv) {
245261
}
246262

247263
{
248-
// Check the iteration with storage hoisting
264+
// Check the interaction with storage hoisting
249265

250266
// This Func may or may not be loaded, depending on y
251267
Func maybe_loaded("maybe_loaded");
@@ -268,6 +284,45 @@ int main(int argc, char **argv) {
268284
output.realize({100, 100});
269285
}
270286

287+
{
288+
// Regression test: a Func loaded only inside a select branch whose
289+
// condition is false at runtime should still get a non-zero
290+
// allocation, because Halide's Select evaluates both branches
291+
// (the load fires regardless of the condition). The producer's
292+
// compute body can still be skipped because `used` is correctly
293+
// gated.
294+
Func producer("producer");
295+
producer(x) = call_counter(x, 0);
296+
297+
Func consumer("consumer");
298+
consumer(x) = select(toggle1, producer(x) + producer(x + 1), 42);
299+
300+
producer.compute_root().vectorize(x, 8);
301+
consumer.vectorize(x, 8);
302+
303+
consumer.jit_handlers().custom_malloc = recording_malloc;
304+
consumer.jit_handlers().custom_free = recording_free;
305+
306+
reset_counts();
307+
toggle1.set(false);
308+
min_alloc_size = (size_t)-1;
309+
Buffer<int> out = consumer.realize({64});
310+
311+
// Producer must be skipped when toggle is false.
312+
check_counts(0);
313+
// ...but the allocation must still happen, sized for the load.
314+
if (min_alloc_size == 0) {
315+
printf("Producer allocation was zero-sized; unconditional load would OOB\n");
316+
exit(1);
317+
}
318+
for (int i = 0; i < 64; i++) {
319+
if (out(i) != 42) {
320+
printf("out(%d) = %d, expected 42\n", i, out(i));
321+
exit(1);
322+
}
323+
}
324+
}
325+
271326
printf("Success!\n");
272327
return 0;
273328
}

0 commit comments

Comments
 (0)