diff --git a/apps/local_laplacian/local_laplacian_generator.cpp b/apps/local_laplacian/local_laplacian_generator.cpp index f907c8439b96..e00dde32ec15 100644 --- a/apps/local_laplacian/local_laplacian_generator.cpp +++ b/apps/local_laplacian/local_laplacian_generator.cpp @@ -165,7 +165,8 @@ class LocalLaplacian : public Halide::Generator { .compute_root() .reorder_storage(x, k, y) .reorder(k, y) - .parallel(y, 8) + .split(y, yo, y, 8) + .parallel(yo) .vectorize(x, 8); outGPyramid[j] .store_at(output, yo) @@ -180,6 +181,12 @@ class LocalLaplacian : public Halide::Generator { gPyramid[j].never_partition_all(); } } + gPyramid[0] + .clone_in(gPyramid[1]) + .store_at(gPyramid[1], yo) + .compute_at(gPyramid[1], y) + .vectorize(x, 8); + outGPyramid[0] .compute_at(output, y) .hoist_storage(output, yo) diff --git a/src/Func.cpp b/src/Func.cpp index 23d03447ba28..623fe30ad4d5 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -16,6 +16,7 @@ #include "CodeGen_LLVM.h" #include "Debug.h" #include "ExprUsesVar.h" +#include "FindCalls.h" #include "Func.h" #include "Function.h" #include "IR.h" @@ -2176,7 +2177,60 @@ Func create_clone_wrapper(Function wrapped_fn, const string &wrapper_name) { return wrapper; } -Func get_wrapper(Function wrapped_fn, string wrapper_name, const vector &fs, bool clone) { +// Walk down the call graph from 'start'. Whenever we find a Func that directly +// calls 'target', record it and stop descending that branch — we don't want to +// pick up unrelated direct callers that happen to live deeper in the subtree. 
+void collect_direct_callers_of(const Function &target,
+                               const Function &start,
+                               std::set<std::string> &visited,
+                               std::map<std::string, Function> &result) {
+    if (start.name() == target.name()) {
+        return;
+    }
+    if (!visited.insert(start.name()).second) {
+        return;
+    }
+    std::map<std::string, Function> direct = find_direct_calls(start);
+    if (direct.count(target.name())) {
+        result.emplace(start.name(), start);
+        return;
+    }
+    for (const auto &kv : direct) {
+        collect_direct_callers_of(target, kv.second, visited, result);
+    }
+}
+
+// Expand a user-supplied list of caller Funcs to the set of *direct* callers of
+// 'target' that lie on a path from any of those callers down to 'target'.
+// Funcs that already directly call 'target' pass through unchanged. If a Func
+// has no static path to 'target' at all, leave it alone: the IR may not yet
+// reflect a wrapper rewrite from a previous in()/clone_in(), and the existing
+// in()/clone_in() semantics permit registering a wrapper for such Funcs.
+vector<Func> resolve_transitive_callers(const Function &target, const vector<Func> &fs) {
+    vector<Func> out;
+    std::set<std::string> emitted;
+    auto emit = [&](const Function &g) {
+        if (emitted.insert(g.name()).second) {
+            out.emplace_back(g);
+        }
+    };
+    for (const Func &f : fs) {
+        std::map<std::string, Function> direct_callers;
+        std::set<std::string> visited;
+        collect_direct_callers_of(target, f.function(), visited, direct_callers);
+        if (direct_callers.empty()) {
+            emit(f.function());
+        } else {
+            for (const auto &kv : direct_callers) {
+                emit(kv.second);
+            }
+        }
+    }
+    return out;
+}
+
+Func get_wrapper(Function wrapped_fn, string wrapper_name, const vector<Func> &fs_in, bool clone) {
+    vector<Func> fs = fs_in.empty() ? fs_in : resolve_transitive_callers(wrapped_fn, fs_in);
     // Either all Funcs in 'fs' have the same wrapper or they don't already
     // have any wrappers. Otherwise, throw an error. If 'fs' is empty, then
     // it is a global wrapper.
diff --git a/src/Func.h b/src/Func.h index b41402582dd5..0bfb591871c7 100644 --- a/src/Func.h +++ b/src/Func.h @@ -1347,6 +1347,12 @@ class Func { for x: g(x, y) = f(x, y) \endcode + * If a Func passed to in() does not directly call this Func, in() acts + * transitively: the Func graph is searched downward from each argument, + * and every direct caller of this Func found along the way is wrapped. + * This is useful when intermediate Funcs are anonymous and not held by + * the user (e.g. a pyramid built via helper functions). + * * using Func::in(), we can write: \code f(x, y) = x + y; @@ -1398,6 +1404,31 @@ class Func { h(x, y) = f(x, y) - 3; \endcode * + * As with Func::in(), clone_in() acts transitively: any Func in 'f'/'fs' + * that does not directly call this Func is replaced by the set of direct + * callers reachable from it along paths that lead to this Func. Only + * this Func is cloned; the intermediate Funcs along the path are not. + * + * For example, given a pipeline that uses sum() (which constructs an + * anonymous inner Func to perform the reduction): + \code + RDom r(0, 5); + f(x, y) = x + y; + g(x, y) = sum(f(x + r, y)); // g calls f via an anonymous Func from sum() + h(x, y) = f(x, y) - 1; // h calls f directly + \endcode + * f.clone_in(g) clones f at the anonymous reduction Func inside g but does + * not clone the reduction Func itself. It is equivalent to this: + \code + RDom r(0, 5); + f(x, y) = x + y; + f_clone(x, y) = x + y; + g(x, y) = sum(f_clone(x + r, y)); // the summation calls the clone + h(x, y) = f(x, y) - 1; // unrelated uses of f are untouched + \endcode + * If the anonymous reduction Func had other consumers besides g, they + * would also see the rewrite from f to f_clone — only this Func is + * cloned, not the intermediates. 
*/ //@{ Func clone_in(const Func &f); diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index efa0bddeb0e8..17ccb738561a 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -422,6 +422,7 @@ tests(GROUPS correctness multithreaded ring_buffer.cpp stream_compaction.cpp thread_safety.cpp + transitive_in.cpp truncated_pyramid.cpp tuple_vector_reduce.cpp vector_cast.cpp diff --git a/test/correctness/transitive_in.cpp b/test/correctness/transitive_in.cpp new file mode 100644 index 000000000000..61c2ec2eb40f --- /dev/null +++ b/test/correctness/transitive_in.cpp @@ -0,0 +1,181 @@ +#include "Halide.h" +#include "check_call_graphs.h" + +#include + +using namespace Halide; + +namespace { + +// Build a small pipeline with anonymous intermediate Funcs, similar in shape +// to local_laplacian's pyramid: we want to call clone_in / in on a non-direct +// caller and have the wrapper be inserted along all paths from that caller +// down to the wrapped Func. +int transitive_clone_in_test() { + Var x("x"), y("y"); + + Func base("base"); + base(x, y) = x + y; + + // Two anonymous helpers that each directly call base. + Func helper_a, helper_b; + helper_a(x, y) = base(x, y) + 1; + helper_b(x, y) = base(x, y) * 2; + + // top transitively calls base via the helpers, but does not directly. + Func top("top"); + top(x, y) = helper_a(x, y) + helper_b(x, y); + + // sibling also uses base directly but is *not* on the path from top. + Func sibling("sibling"); + sibling(x, y) = base(x, y) - 1; + + // Cloning base into top should expand to {helper_a, helper_b}, but must + // leave sibling's call to base untouched. + Func cloned = base.clone_in(top); + + Func out("out"); + out(x, y) = top(x, y) + sibling(x, y); + + base.compute_root(); + helper_a.compute_root(); + helper_b.compute_root(); + cloned.compute_root(); + sibling.compute_root(); + top.compute_root(); + + // First check: numerical correctness. 
+ Pipeline p(out); + Buffer result = p.realize({16, 16}); + auto check = [](int xv, int yv) { + int b = xv + yv; + int ha = b + 1; + int hb = b * 2; + int t = ha + hb; + int s = b - 1; + return t + s; + }; + if (check_image2(result, check) != 0) { + return 1; + } + + // Second check: helper_a and helper_b should load from the clone, not + // base; sibling should still load from base. + CheckCalls *checker = new CheckCalls; + Pipeline p2(out); + p2.add_custom_lowering_pass(checker); + p2.compile_to_module(p2.infer_arguments(), ""); + const auto &calls = checker->calls; + + auto loads_from = [&](const std::string &producer, const std::string &callee) { + auto it = calls.find(producer); + if (it == calls.end()) { + printf("Producer %s not found\n", producer.c_str()); + return false; + } + for (const std::string &c : it->second) { + if (c == callee) return true; + } + return false; + }; + + if (loads_from(helper_a.name(), base.name())) { + printf("helper_a should not directly call base after clone_in\n"); + return 1; + } + if (loads_from(helper_b.name(), base.name())) { + printf("helper_b should not directly call base after clone_in\n"); + return 1; + } + if (!loads_from(helper_a.name(), cloned.name())) { + printf("helper_a should call the clone\n"); + return 1; + } + if (!loads_from(helper_b.name(), cloned.name())) { + printf("helper_b should call the clone\n"); + return 1; + } + if (!loads_from(sibling.name(), base.name())) { + printf("sibling should still call base\n"); + return 1; + } + + return 0; +} + +// Direct callers passed to clone_in should still work (no expansion needed). +int direct_clone_in_still_works_test() { + Var x("x"), y("y"); + Func f("f"), g("g"); + f(x, y) = x + y; + g(x, y) = f(x, y) + 7; + Func cloned = f.clone_in(g); + f.compute_root(); + cloned.compute_root(); + Buffer r = g.realize({8, 8}); + return check_image2(r, [](int xv, int yv) { return xv + yv + 7; }); +} + +// in() is also transitive. 
+int transitive_in_test() {
+    Var x("x"), y("y");
+    Func base("base");
+    base(x, y) = x + y;
+    Func mid;
+    mid(x, y) = base(x, y) + 3;
+    Func top("top");
+    top(x, y) = mid(x, y) * 2;
+
+    // base.in(top) should resolve to base.in(mid).
+    Func wrapper = base.in(top);
+
+    base.compute_root();
+    mid.compute_root();
+    wrapper.compute_root();
+    top.compute_root();
+
+    Buffer<int> r = top.realize({8, 8});
+    if (check_image2(r, [](int xv, int yv) { return (xv + yv + 3) * 2; }) != 0) {
+        return 1;
+    }
+
+    CheckCalls *checker = new CheckCalls;
+    Pipeline p(top);
+    p.add_custom_lowering_pass(checker);
+    p.compile_to_module(p.infer_arguments(), "");
+    const auto &calls = checker->calls;
+    auto it = calls.find(mid.name());
+    if (it == calls.end()) {
+        printf("mid not found in call graph\n");
+        return 1;
+    }
+    for (const auto &c : it->second) {
+        if (c == base.name()) {
+            printf("mid should not directly call base after in()\n");
+            return 1;
+        }
+    }
+    return 0;
+}
+
+}  // namespace
+
+int main(int argc, char **argv) {
+    printf("Running transitive_clone_in_test\n");
+    if (transitive_clone_in_test() != 0) {
+        printf("transitive_clone_in_test failed\n");
+        return 1;
+    }
+    printf("Running direct_clone_in_still_works_test\n");
+    if (direct_clone_in_still_works_test() != 0) {
+        printf("direct_clone_in_still_works_test failed\n");
+        return 1;
+    }
+    printf("Running transitive_in_test\n");
+    if (transitive_in_test() != 0) {
+        printf("transitive_in_test failed\n");
+        return 1;
+    }
+    printf("Success!\n");
+    return 0;
+}