Skip to content

Commit 997fb66

Browse files
authored
Optimize inject_host_dev_buffer_copies pass (#8996)
* Refine lambda argument requirements in IRMutator and IRVisitor
* Early exit in the loop-checking visitor
* Compute last_use in-line
* Avoid redundant FindBufferUsage in For loop visitor
1 parent 1e8df9b commit 997fb66

3 files changed

Lines changed: 92 additions & 90 deletions

File tree

src/IRMutator.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ struct LambdaMutator final : IRMutator {
136136

137137
template<typename T>
138138
auto visit_impl(const T *op) {
139+
// Catch lambdas that accidentally take non-const T* (e.g. For *
140+
// instead of const For *). Such lambdas silently never match.
141+
static_assert(!std::is_invocable_v<decltype(handlers), LambdaMutator *, T *> ||
142+
std::is_invocable_v<decltype(handlers), LambdaMutator *, const T *>,
143+
"mutate_with lambda takes a non-const node pointer; use const T * instead");
139144
if constexpr (std::is_invocable_v<decltype(handlers), LambdaMutator *, const T *>) {
140145
return handlers(this, op);
141146
} else {
@@ -336,9 +341,15 @@ auto mutate_with(const T &ir, Lambdas &&...lambdas) {
336341
return LambdaMutatorGeneric{std::forward<Lambdas>(lambdas)...}.mutate(ir);
337342
} else {
338343
LambdaMutator mutator{std::forward<Lambdas>(lambdas)...};
344+
// Each lambda must take two args: (auto *self, <some-pointer> op).
345+
// Test with const IntImm * (works for auto * params via deduction) and
346+
// nullptr_t (works for specific-type params via implicit conversion).
339347
constexpr bool all_take_two_args =
340-
(std::is_invocable_v<Lambdas, decltype(&mutator), decltype(nullptr)> && ...);
341-
static_assert(all_take_two_args);
348+
((std::is_invocable_v<Lambdas, decltype(&mutator), const IntImm *> ||
349+
std::is_invocable_v<Lambdas, decltype(&mutator), decltype(nullptr)>) &&
350+
...);
351+
static_assert(all_take_two_args,
352+
"All mutate_with lambdas must take two arguments: (auto *self, const T *op)");
342353
return mutator.mutate(ir);
343354
}
344355
}

src/IRVisitor.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ struct LambdaVisitor final : IRVisitor {
9898

9999
template<typename T>
100100
auto visit_impl(const T *op) {
101+
// Catch lambdas that accidentally take non-const T* (e.g. For *
102+
// instead of const For *). Such lambdas silently never match.
103+
static_assert(!std::is_invocable_v<decltype(handlers), LambdaVisitor *, T *> ||
104+
std::is_invocable_v<decltype(handlers), LambdaVisitor *, const T *>,
105+
"visit_with lambda takes a non-const node pointer; use const T * instead");
101106
if constexpr (std::is_invocable_v<decltype(handlers), LambdaVisitor *, const T *>) {
102107
return handlers(this, op);
103108
} else {
@@ -255,9 +260,15 @@ struct LambdaVisitor final : IRVisitor {
255260
template<typename... Lambdas>
256261
void visit_with(const IRNode *ir, Lambdas &&...lambdas) {
257262
LambdaVisitor visitor{std::forward<Lambdas>(lambdas)...};
263+
// Each lambda must take two args: (auto *self, <some-pointer> op).
264+
// Test with const IntImm * (works for auto * params via deduction) and
265+
// nullptr_t (works for specific-type params via implicit conversion).
258266
constexpr bool all_take_two_args =
259-
(std::is_invocable_v<Lambdas, decltype(&visitor), decltype(nullptr)> && ...);
260-
static_assert(all_take_two_args);
267+
((std::is_invocable_v<Lambdas, decltype(&visitor), const IntImm *> ||
268+
std::is_invocable_v<Lambdas, decltype(&visitor), decltype(nullptr)>) &&
269+
...);
270+
static_assert(all_take_two_args,
271+
"All visit_with lambdas must take two arguments: (auto *self, const T *op)");
261272
ir->accept(&visitor);
262273
}
263274

src/InjectHostDevBufferCopies.cpp

Lines changed: 66 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,27 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
214214
return call_extern_and_assert("halide_device_free", {buffer_var()});
215215
}
216216

217-
Stmt do_copies(Stmt s) {
217+
Stmt do_copies(Stmt s, FindBufferUsage *precomputed = nullptr) {
218218
// Sniff what happens to the buffer inside the stmt
219-
FindBufferUsage finder(buffer, DeviceAPI::Host);
220-
s.accept(&finder);
219+
FindBufferUsage local_finder(buffer, DeviceAPI::Host);
220+
if (!precomputed) {
221+
s.accept(&local_finder);
222+
precomputed = &local_finder;
223+
}
224+
FindBufferUsage &finder = *precomputed;
225+
226+
// Track the last leaf that uses this buffer, so that the
227+
// caller can inject a device free after it.
228+
if (!finder.devices_touched.empty() ||
229+
!finder.devices_touched_by_extern.empty()) {
230+
last_use = s;
231+
// Block::make flattens nested Blocks, destroying pointer
232+
// identity. Walk to the last non-Block element so that
233+
// inject_free_after_last_use can find it after wrapping.
234+
while (const Block *b = last_use.as<Block>()) {
235+
last_use = b->rest;
236+
}
237+
}
221238

222239
// Insert any appropriate copies/allocations before, and set
223240
// dirty flags after. Do not recurse into the stmt.
@@ -344,15 +361,24 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
344361
// The state of the buffer going into the loop is the
345362
// union of the state before the loop starts and the state
346363
// after one iteration. Just forget everything we know.
364+
Stmt saved_last_use = last_use;
365+
last_use = {};
347366
state = State{};
348367
Stmt s = IRMutator::visit(op);
368+
// Collapse inner last_use to the For itself, since a
369+
// device free must go after the entire loop.
370+
if (last_use.defined()) {
371+
last_use = s;
372+
} else {
373+
last_use = saved_last_use;
374+
}
349375
// The state after analyzing the loop body might not be the
350376
// true state if the loop ran for zero iterations. So
351377
// forget everything again.
352378
state = State{};
353379
return s;
354380
} else {
355-
return do_copies(op);
381+
return do_copies(op, &finder);
356382
}
357383
}
358384

@@ -387,8 +413,12 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
387413
// host dirties from getting in between blocks of store stmts
388414
// that could be interleaved.
389415
bool has_loops = false;
390-
visit_with(op, [&](auto *, const For *op) {
391-
has_loops = true;
416+
visit_with(op, [&](auto *self, const auto *s) {
417+
if constexpr (std::is_same_v<decltype(s), const For *>) {
418+
has_loops = true;
419+
} else if (!has_loops) {
420+
self->visit_base(s);
421+
}
392422
});
393423
if (has_loops) {
394424
return IRMutator::visit(op);
@@ -402,16 +432,36 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
402432
}
403433

404434
Stmt visit(const IfThenElse *op) override {
435+
Stmt saved_last_use = last_use;
436+
437+
last_use = {};
405438
State old = state;
406439
Stmt then_case = mutate(op->then_case);
440+
bool then_touches = last_use.defined();
407441
State then_state = state;
442+
443+
last_use = {};
408444
state = old;
409445
Stmt else_case = mutate(op->else_case);
446+
bool else_touches = last_use.defined();
410447
state.union_with(then_state);
411-
return IfThenElse::make(op->condition, then_case, else_case);
448+
449+
Stmt result = IfThenElse::make(op->condition, then_case, else_case);
450+
// Collapse inner last_use to the IfThenElse itself, since
451+
// a device free must go after the entire conditional.
452+
if (then_touches || else_touches) {
453+
last_use = result;
454+
} else {
455+
last_use = saved_last_use;
456+
}
457+
return result;
412458
}
413459

414460
public:
461+
// The last leaf stmt that uses the buffer. Used by the caller
462+
// to inject a device free after the last use.
463+
Stmt last_use;
464+
415465
InjectBufferCopiesForSingleBuffer(const std::string &b, bool e, MemoryType m)
416466
: buffer(b), is_external(e), memory_type(m) {
417467
if (is_external) {
@@ -427,72 +477,6 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
427477
}
428478
};
429479

430-
// Find the last use of a given buffer, which will used later for injecting
431-
// device free calls.
432-
class FindLastUse : public IRVisitor {
433-
public:
434-
Stmt last_use;
435-
436-
FindLastUse(const string &b)
437-
: buffer(b) {
438-
}
439-
440-
private:
441-
string buffer;
442-
443-
using IRVisitor::visit;
444-
445-
void check_and_record_last_use(const Stmt &s) {
446-
// Sniff what happens to the buffer inside the stmt
447-
FindBufferUsage finder(buffer, DeviceAPI::Host);
448-
s.accept(&finder);
449-
450-
if (!finder.devices_touched.empty() ||
451-
!finder.devices_touched_by_extern.empty()) {
452-
last_use = s;
453-
}
454-
}
455-
456-
// We break things down into a serial sequence of leaf
457-
// stmts similar to InjectBufferCopiesForSingleBuffer.
458-
void visit(const For *op) override {
459-
check_and_record_last_use(op);
460-
}
461-
462-
void visit(const Fork *op) override {
463-
check_and_record_last_use(op);
464-
}
465-
466-
void visit(const Evaluate *op) override {
467-
check_and_record_last_use(op);
468-
}
469-
470-
void visit(const LetStmt *op) override {
471-
// If op->value uses the buffer, we need to treat this as a
472-
// single leaf. Otherwise we can recurse.
473-
FindBufferUsage finder(buffer, DeviceAPI::Host);
474-
op->value.accept(&finder);
475-
if (finder.devices_touched.empty() &&
476-
finder.devices_touched_by_extern.empty()) {
477-
IRVisitor::visit(op);
478-
} else {
479-
check_and_record_last_use(op);
480-
}
481-
}
482-
483-
void visit(const AssertStmt *op) override {
484-
check_and_record_last_use(op);
485-
}
486-
487-
void visit(const Store *op) override {
488-
check_and_record_last_use(op);
489-
}
490-
491-
void visit(const IfThenElse *op) override {
492-
check_and_record_last_use(op);
493-
}
494-
};
495-
496480
// Inject the buffer-handling logic for all internal
497481
// allocations. Inputs and outputs are handled below.
498482
class InjectBufferCopies : public IRMutator {
@@ -629,11 +613,9 @@ class InjectBufferCopies : public IRMutator {
629613

630614
// Make a device_and_host_free stmt
631615

632-
FindLastUse last_use(op->name);
633-
body.accept(&last_use);
634-
if (last_use.last_use.defined()) {
616+
if (injector.last_use.defined()) {
635617
Stmt device_free = call_extern_and_assert("halide_device_and_host_free", {buffer});
636-
body = inject_free_after_last_use(body, last_use.last_use, device_free);
618+
body = inject_free_after_last_use(body, injector.last_use, device_free);
637619
}
638620

639621
Expr device_interface = make_device_interface_call(touching_device, op->memory_type);
@@ -646,18 +628,16 @@ class InjectBufferCopies : public IRMutator {
646628
// only touched on device, or touched on multiple
647629
// devices. Do separate device and host allocations.
648630

649-
// Add a device destructor
650-
body = InjectDeviceDestructor(buffer_name).mutate(body);
651-
652-
// Make a device_free stmt
653-
654-
FindLastUse last_use(op->name);
655-
body.accept(&last_use);
656-
if (last_use.last_use.defined()) {
631+
// Inject device_free after the last use. Must happen
632+
// before InjectDeviceDestructor, which modifies the tree.
633+
if (injector.last_use.defined()) {
657634
Stmt device_free = call_extern_and_assert("halide_device_free", {buffer});
658-
body = inject_free_after_last_use(body, last_use.last_use, device_free);
635+
body = inject_free_after_last_use(body, injector.last_use, device_free);
659636
}
660637

638+
// Add a device destructor
639+
body = InjectDeviceDestructor(buffer_name).mutate(body);
640+
661641
Expr condition = op->condition;
662642
bool touched_on_one_device = !touched_on_host && finder.devices_touched.size() == 1 &&
663643
(finder.devices_touched_by_extern.empty() ||

0 commit comments

Comments
 (0)