@@ -214,10 +214,27 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
214214 return call_extern_and_assert (" halide_device_free" , {buffer_var ()});
215215 }
216216
217- Stmt do_copies (Stmt s) {
217+ Stmt do_copies (Stmt s, FindBufferUsage *precomputed = nullptr ) {
218218 // Sniff what happens to the buffer inside the stmt
219- FindBufferUsage finder (buffer, DeviceAPI::Host);
220- s.accept (&finder);
219+ FindBufferUsage local_finder (buffer, DeviceAPI::Host);
220+ if (!precomputed) {
221+ s.accept (&local_finder);
222+ precomputed = &local_finder;
223+ }
224+ FindBufferUsage &finder = *precomputed;
225+
226+ // Track the last leaf that uses this buffer, so that the
227+ // caller can inject a device free after it.
228+ if (!finder.devices_touched .empty () ||
229+ !finder.devices_touched_by_extern .empty ()) {
230+ last_use = s;
231+ // Block::make flattens nested Blocks, destroying pointer
232+ // identity. Walk to the last non-Block element so that
233+ // inject_free_after_last_use can find it after wrapping.
234+ while (const Block *b = last_use.as <Block>()) {
235+ last_use = b->rest ;
236+ }
237+ }
221238
222239 // Insert any appropriate copies/allocations before, and set
223240 // dirty flags after. Do not recurse into the stmt.
@@ -344,15 +361,24 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
344361 // The state of the buffer going into the loop is the
345362 // union of the state before the loop starts and the state
346363 // after one iteration. Just forget everything we know.
364+ Stmt saved_last_use = last_use;
365+ last_use = {};
347366 state = State{};
348367 Stmt s = IRMutator::visit (op);
368+ // Collapse inner last_use to the For itself, since a
369+ // device free must go after the entire loop.
370+ if (last_use.defined ()) {
371+ last_use = s;
372+ } else {
373+ last_use = saved_last_use;
374+ }
349375 // The state after analyzing the loop body might not be the
350376 // true state if the loop ran for zero iterations. So
351377 // forget everything again.
352378 state = State{};
353379 return s;
354380 } else {
355- return do_copies (op);
381+ return do_copies (op, &finder );
356382 }
357383 }
358384
@@ -387,8 +413,12 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
387413 // host dirties from getting in between blocks of store stmts
388414 // that could be interleaved.
389415 bool has_loops = false ;
390- visit_with (op, [&](auto *, const For *op) {
391- has_loops = true ;
416+ visit_with (op, [&](auto *self, const auto *s) {
417+ if constexpr (std::is_same_v<decltype (s), const For *>) {
418+ has_loops = true ;
419+ } else if (!has_loops) {
420+ self->visit_base (s);
421+ }
392422 });
393423 if (has_loops) {
394424 return IRMutator::visit (op);
@@ -402,16 +432,36 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
402432 }
403433
404434 Stmt visit (const IfThenElse *op) override {
435+ Stmt saved_last_use = last_use;
436+
437+ last_use = {};
405438 State old = state;
406439 Stmt then_case = mutate (op->then_case );
440+ bool then_touches = last_use.defined ();
407441 State then_state = state;
442+
443+ last_use = {};
408444 state = old;
409445 Stmt else_case = mutate (op->else_case );
446+ bool else_touches = last_use.defined ();
410447 state.union_with (then_state);
411- return IfThenElse::make (op->condition , then_case, else_case);
448+
449+ Stmt result = IfThenElse::make (op->condition , then_case, else_case);
450+ // Collapse inner last_use to the IfThenElse itself, since
451+ // a device free must go after the entire conditional.
452+ if (then_touches || else_touches) {
453+ last_use = result;
454+ } else {
455+ last_use = saved_last_use;
456+ }
457+ return result;
412458 }
413459
414460public:
461+ // The last leaf stmt that uses the buffer. Used by the caller
462+ // to inject a device free after the last use.
463+ Stmt last_use;
464+
415465 InjectBufferCopiesForSingleBuffer (const std::string &b, bool e, MemoryType m)
416466 : buffer(b), is_external(e), memory_type(m) {
417467 if (is_external) {
@@ -427,72 +477,6 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
427477 }
428478};
429479
430- // Find the last use of a given buffer, which will used later for injecting
431- // device free calls.
432- class FindLastUse : public IRVisitor {
433- public:
434- Stmt last_use;
435-
436- FindLastUse (const string &b)
437- : buffer(b) {
438- }
439-
440- private:
441- string buffer;
442-
443- using IRVisitor::visit;
444-
445- void check_and_record_last_use (const Stmt &s) {
446- // Sniff what happens to the buffer inside the stmt
447- FindBufferUsage finder (buffer, DeviceAPI::Host);
448- s.accept (&finder);
449-
450- if (!finder.devices_touched .empty () ||
451- !finder.devices_touched_by_extern .empty ()) {
452- last_use = s;
453- }
454- }
455-
456- // We break things down into a serial sequence of leaf
457- // stmts similar to InjectBufferCopiesForSingleBuffer.
458- void visit (const For *op) override {
459- check_and_record_last_use (op);
460- }
461-
462- void visit (const Fork *op) override {
463- check_and_record_last_use (op);
464- }
465-
466- void visit (const Evaluate *op) override {
467- check_and_record_last_use (op);
468- }
469-
470- void visit (const LetStmt *op) override {
471- // If op->value uses the buffer, we need to treat this as a
472- // single leaf. Otherwise we can recurse.
473- FindBufferUsage finder (buffer, DeviceAPI::Host);
474- op->value .accept (&finder);
475- if (finder.devices_touched .empty () &&
476- finder.devices_touched_by_extern .empty ()) {
477- IRVisitor::visit (op);
478- } else {
479- check_and_record_last_use (op);
480- }
481- }
482-
483- void visit (const AssertStmt *op) override {
484- check_and_record_last_use (op);
485- }
486-
487- void visit (const Store *op) override {
488- check_and_record_last_use (op);
489- }
490-
491- void visit (const IfThenElse *op) override {
492- check_and_record_last_use (op);
493- }
494- };
495-
496480// Inject the buffer-handling logic for all internal
497481// allocations. Inputs and outputs are handled below.
498482class InjectBufferCopies : public IRMutator {
@@ -629,11 +613,9 @@ class InjectBufferCopies : public IRMutator {
629613
630614 // Make a device_and_host_free stmt
631615
632- FindLastUse last_use (op->name );
633- body.accept (&last_use);
634- if (last_use.last_use .defined ()) {
616+ if (injector.last_use .defined ()) {
635617 Stmt device_free = call_extern_and_assert (" halide_device_and_host_free" , {buffer});
636- body = inject_free_after_last_use (body, last_use .last_use , device_free);
618+ body = inject_free_after_last_use (body, injector .last_use , device_free);
637619 }
638620
639621 Expr device_interface = make_device_interface_call (touching_device, op->memory_type );
@@ -646,18 +628,16 @@ class InjectBufferCopies : public IRMutator {
646628 // only touched on device, or touched on multiple
647629 // devices. Do separate device and host allocations.
648630
649- // Add a device destructor
650- body = InjectDeviceDestructor (buffer_name).mutate (body);
651-
652- // Make a device_free stmt
653-
654- FindLastUse last_use (op->name );
655- body.accept (&last_use);
656- if (last_use.last_use .defined ()) {
631+ // Inject device_free after the last use. Must happen
632+ // before InjectDeviceDestructor, which modifies the tree.
633+ if (injector.last_use .defined ()) {
657634 Stmt device_free = call_extern_and_assert (" halide_device_free" , {buffer});
658- body = inject_free_after_last_use (body, last_use .last_use , device_free);
635+ body = inject_free_after_last_use (body, injector .last_use , device_free);
659636 }
660637
638+ // Add a device destructor
639+ body = InjectDeviceDestructor (buffer_name).mutate (body);
640+
661641 Expr condition = op->condition ;
662642 bool touched_on_one_device = !touched_on_host && finder.devices_touched .size () == 1 &&
663643 (finder.devices_touched_by_extern .empty () ||
0 commit comments