Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion apps/local_laplacian/local_laplacian_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ class LocalLaplacian : public Halide::Generator<LocalLaplacian> {
.compute_root()
.reorder_storage(x, k, y)
.reorder(k, y)
.parallel(y, 8)
.split(y, yo, y, 8)
.parallel(yo)
.vectorize(x, 8);
outGPyramid[j]
.store_at(output, yo)
Expand All @@ -180,6 +181,12 @@ class LocalLaplacian : public Halide::Generator<LocalLaplacian> {
gPyramid[j].never_partition_all();
}
}
gPyramid[0]
.clone_in(gPyramid[1])
.store_at(gPyramid[1], yo)
.compute_at(gPyramid[1], y)
.vectorize(x, 8);

outGPyramid[0]
.compute_at(output, y)
.hoist_storage(output, yo)
Expand Down
56 changes: 55 additions & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "CodeGen_LLVM.h"
#include "Debug.h"
#include "ExprUsesVar.h"
#include "FindCalls.h"
#include "Func.h"
#include "Function.h"
#include "IR.h"
Expand Down Expand Up @@ -2176,7 +2177,60 @@ Func create_clone_wrapper(Function wrapped_fn, const string &wrapper_name) {
return wrapper;
}

Func get_wrapper(Function wrapped_fn, string wrapper_name, const vector<Func> &fs, bool clone) {
// Walk down the call graph from 'start'. Whenever we find a Func that directly
// calls 'target', record it and stop descending that branch — we don't want to
// pick up unrelated direct callers that happen to live deeper in the subtree.
void collect_direct_callers_of(const Function &target,
const Function &start,
std::set<std::string> &visited,
std::map<std::string, Function> &result) {
if (start.name() == target.name()) {
return;
}
if (!visited.insert(start.name()).second) {
return;
}
std::map<std::string, Function> direct = find_direct_calls(start);
if (direct.count(target.name())) {
result.emplace(start.name(), start);
return;
}
for (const auto &kv : direct) {
collect_direct_callers_of(target, kv.second, visited, result);
}
}

// Translate the caller list given to in()/clone_in() into the Funcs that
// should actually receive the wrapper. Each entry of 'fs' is replaced by the
// direct callers of 'target' found by searching downward from it; an entry
// that calls 'target' directly therefore maps to itself. An entry from which
// no direct caller can be found is kept as-is rather than dropped: the IR may
// not yet reflect a wrapper rewrite from an earlier in()/clone_in(), and the
// existing semantics allow registering a wrapper for such a Func. Every Func
// appears at most once in the returned list, in first-seen order.
vector<Func> resolve_transitive_callers(const Function &target, const vector<Func> &fs) {
    vector<Func> resolved;
    std::set<std::string> already_listed;

    for (const Func &requested : fs) {
        // A fresh visited set per entry, so that one entry's search cannot
        // hide direct callers from another entry's search.
        std::set<std::string> visited;
        std::map<std::string, Function> found;
        collect_direct_callers_of(target, requested.function(), visited, found);

        // Candidates for this entry: the direct callers that were found, or
        // the entry itself when the search came back empty.
        vector<Function> candidates;
        if (found.empty()) {
            candidates.push_back(requested.function());
        } else {
            for (const auto &entry : found) {
                candidates.push_back(entry.second);
            }
        }

        for (const Function &candidate : candidates) {
            const bool is_new = already_listed.insert(candidate.name()).second;
            if (is_new) {
                resolved.emplace_back(candidate);
            }
        }
    }
    return resolved;
}

Func get_wrapper(Function wrapped_fn, string wrapper_name, const vector<Func> &fs_in, bool clone) {
vector<Func> fs = fs_in.empty() ? fs_in : resolve_transitive_callers(wrapped_fn, fs_in);
// Either all Funcs in 'fs' have the same wrapper or they don't already
// have any wrappers. Otherwise, throw an error. If 'fs' is empty, then
// it is a global wrapper.
Expand Down
9 changes: 9 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,12 @@ class Func {
for x:
g(x, y) = f(x, y)
\endcode
* If a Func passed to in() does not directly call this Func, in() acts
* transitively: the Func graph is searched downward from each argument,
* and the first direct caller of this Func found along each path is
* wrapped (the search does not continue below a direct caller). An
* argument from which no direct caller can be found is registered as
* given. This is useful when intermediate Funcs are anonymous and not
* held by the user (e.g. a pyramid built via helper functions).
*
* using Func::in(), we can write:
\code
f(x, y) = x + y;
Expand Down Expand Up @@ -1398,6 +1404,9 @@ class Func {
h(x, y) = f(x, y) - 3;
\endcode
*
* As with Func::in(), clone_in() acts transitively: any Func in 'f'/'fs'
* that does not directly call this Func is replaced by the set of direct
* callers reachable from it along paths that lead to this Func. A Func
* from which no such direct caller can be found is left in the list
* unchanged.
*/
//@{
Func clone_in(const Func &f);
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ tests(GROUPS correctness multithreaded
ring_buffer.cpp
stream_compaction.cpp
thread_safety.cpp
transitive_in.cpp
truncated_pyramid.cpp
tuple_vector_reduce.cpp
vector_cast.cpp
Expand Down
181 changes: 181 additions & 0 deletions test/correctness/transitive_in.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#include "Halide.h"
#include "check_call_graphs.h"

#include <cstdio>

using namespace Halide;

namespace {

// Build a small pipeline with anonymous intermediate Funcs, similar in shape
// to local_laplacian's pyramid: we want to call clone_in / in on a non-direct
// caller and have the wrapper be inserted along all paths from that caller
// down to the wrapped Func.
int transitive_clone_in_test() {
Var x("x"), y("y");

Func base("base");
base(x, y) = x + y;

// Two anonymous helpers that each directly call base.
Func helper_a, helper_b;
helper_a(x, y) = base(x, y) + 1;
helper_b(x, y) = base(x, y) * 2;

// top transitively calls base via the helpers, but does not directly.
Func top("top");
top(x, y) = helper_a(x, y) + helper_b(x, y);

// sibling also uses base directly but is *not* on the path from top.
Func sibling("sibling");
sibling(x, y) = base(x, y) - 1;

// Cloning base into top should expand to {helper_a, helper_b}, but must
// leave sibling's call to base untouched.
Func cloned = base.clone_in(top);

Func out("out");
out(x, y) = top(x, y) + sibling(x, y);

base.compute_root();
helper_a.compute_root();
helper_b.compute_root();
cloned.compute_root();
sibling.compute_root();
top.compute_root();

// First check: numerical correctness.
Pipeline p(out);
Buffer<int> result = p.realize({16, 16});
auto check = [](int xv, int yv) {
int b = xv + yv;
int ha = b + 1;
int hb = b * 2;
int t = ha + hb;
int s = b - 1;
return t + s;
};
if (check_image2(result, check) != 0) {
return 1;
}

// Second check: helper_a and helper_b should load from the clone, not
// base; sibling should still load from base.
CheckCalls *checker = new CheckCalls;
Pipeline p2(out);
p2.add_custom_lowering_pass(checker);
p2.compile_to_module(p2.infer_arguments(), "");
const auto &calls = checker->calls;

auto loads_from = [&](const std::string &producer, const std::string &callee) {
auto it = calls.find(producer);
if (it == calls.end()) {
printf("Producer %s not found\n", producer.c_str());
return false;
}
for (const std::string &c : it->second) {
if (c == callee) return true;
}
return false;
};

if (loads_from(helper_a.name(), base.name())) {
printf("helper_a should not directly call base after clone_in\n");
return 1;
}
if (loads_from(helper_b.name(), base.name())) {
printf("helper_b should not directly call base after clone_in\n");
return 1;
}
if (!loads_from(helper_a.name(), cloned.name())) {
printf("helper_a should call the clone\n");
return 1;
}
if (!loads_from(helper_b.name(), cloned.name())) {
printf("helper_b should call the clone\n");
return 1;
}
if (!loads_from(sibling.name(), base.name())) {
printf("sibling should still call base\n");
return 1;
}

return 0;
}

// Direct callers passed to clone_in should still work (no expansion needed).
int direct_clone_in_still_works_test() {
Var x("x"), y("y");
Func f("f"), g("g");
f(x, y) = x + y;
g(x, y) = f(x, y) + 7;
Func cloned = f.clone_in(g);
f.compute_root();
cloned.compute_root();
Buffer<int> r = g.realize({8, 8});
return check_image2(r, [](int xv, int yv) { return xv + yv + 7; });
}

// in() is also transitive. 'top' reaches 'base' only through the unnamed
// Func 'mid', so base.in(top) is expected to behave like base.in(mid).
// The test checks the realized values, then checks the lowered call graph in
// both directions: 'mid' must not load from 'base', and 'mid' must load from
// the wrapper. Returns 0 on success, 1 on failure.
int transitive_in_test() {
    Var x("x"), y("y");
    Func base("base");
    base(x, y) = x + y;
    Func mid;
    mid(x, y) = base(x, y) + 3;
    Func top("top");
    top(x, y) = mid(x, y) * 2;

    // base.in(top) should resolve to base.in(mid).
    Func wrapper = base.in(top);

    // Everything at root so that each stage shows up as its own producer.
    base.compute_root();
    mid.compute_root();
    wrapper.compute_root();
    top.compute_root();

    Buffer<int> r = top.realize({8, 8});
    if (check_image2(r, [](int xv, int yv) { return (xv + yv + 3) * 2; }) != 0) {
        return 1;
    }

    // NOTE(review): the raw pointer is handed to the pipeline, which
    // presumably takes ownership of the pass; confirm against the
    // add_custom_lowering_pass documentation.
    CheckCalls *checker = new CheckCalls;
    Pipeline p(top);
    p.add_custom_lowering_pass(checker);
    p.compile_to_module(p.infer_arguments(), "");
    const auto &calls = checker->calls;
    auto it = calls.find(mid.name());
    if (it == calls.end()) {
        printf("mid not found in call graph\n");
        return 1;
    }

    // Checking only for the absence of 'base' would pass vacuously if 'mid'
    // recorded no loads at all, so also require that the wrapper is present.
    bool calls_wrapper = false;
    for (const auto &c : it->second) {
        if (c == base.name()) {
            printf("mid should not directly call base after in()\n");
            return 1;
        }
        if (c == wrapper.name()) {
            calls_wrapper = true;
        }
    }
    if (!calls_wrapper) {
        printf("mid should call the wrapper\n");
        return 1;
    }
    return 0;
}

} // namespace

// Runs each sub-test in order, announcing it first. Stops at the first
// failure and exits with status 1; prints "Success!" and exits with status 0
// when every sub-test returns 0.
int main(int argc, char **argv) {
    struct TestCase {
        const char *name;
        int (*run)();
    };
    const TestCase test_cases[] = {
        {"transitive_clone_in_test", transitive_clone_in_test},
        {"direct_clone_in_still_works_test", direct_clone_in_still_works_test},
        {"transitive_in_test", transitive_in_test},
    };

    for (const TestCase &test_case : test_cases) {
        printf("Running %s\n", test_case.name);
        if (test_case.run() != 0) {
            printf("%s failed\n", test_case.name);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}
Loading