From 5e5170d5d13cd36b708b20aa4c9d415105ec61cb Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Thu, 25 Jun 2026 15:58:56 -0400 Subject: [PATCH] Restructure fuse function to test reordering and add tests. Co-authored by: Claude Sonnet 4.6 --- src/Func.cpp | 98 ++++++++++++++++++-------------- test/correctness/fuse.cpp | 22 +++++++ test/error/CMakeLists.txt | 2 + test/error/fuse_same_var.cpp | 13 +++++ test/error/rvar_fuse_reorder.cpp | 19 +++++++ 5 files changed, 111 insertions(+), 43 deletions(-) create mode 100644 test/error/fuse_same_var.cpp create mode 100644 test/error/rvar_fuse_reorder.cpp diff --git a/src/Func.cpp b/src/Func.cpp index f4e4298fae07..81c6c7cfc7fa 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -1317,58 +1317,70 @@ Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRV debug(4) << "In schedule for " << name() << ", fuse " << outer.name() << " and " << inner.name() << " into " << fused.name() << "\n"; - // Replace the old dimensions with the new dimension in the dims list - bool found_outer = false, found_inner = false; - string inner_name, outer_name, fused_name; vector<Dim> &dims = definition.schedule().dims(); - - DimType outer_type = DimType::PureRVar; - for (size_t i = 0; (!found_outer) && i < dims.size(); i++) { + int inner_pos = -1, outer_pos = -1; + for (int i = 0; i < (int)dims.size(); i++) { + if (dim_match(dims[i], inner)) { + inner_pos = i; + } if (dim_match(dims[i], outer)) { - found_outer = true; - outer_name = dims[i].var; - outer_type = dims[i].dim_type; - dims.erase(dims.begin() + i); + outer_pos = i; } } - if (!found_outer) { - user_error << "In schedule for " << name() - << ", could not find outer fuse dimension: " - << outer.name() - << "\n" - << dump_argument_list(); - } - - for (size_t i = 0; (!found_inner) && i < dims.size(); i++) { - if (dim_match(dims[i], inner)) { - found_inner = true; - inner_name = dims[i].var; - fused_name = inner_name + "." 
+ fused.name(); - dims[i].var = fused_name; - - if (dims[i].dim_type == DimType::ImpureRVar || - outer_type == DimType::ImpureRVar) { - dims[i].dim_type = DimType::ImpureRVar; - } else if (dims[i].dim_type == DimType::PureRVar || - outer_type == DimType::PureRVar) { - dims[i].dim_type = DimType::PureRVar; - } else { - dims[i].dim_type = DimType::PureVar; + user_assert(inner_pos >= 0) << "In schedule for " << name() + << ", could not find inner fuse dimension: " + << inner.name() << "\n" + << dump_argument_list(); + user_assert(outer_pos >= 0) << "In schedule for " << name() + << ", could not find outer fuse dimension: " + << outer.name() << "\n" + << dump_argument_list(); + user_assert(inner_pos != outer_pos) << "In schedule for " << name() + << ", inner and outer fuse dimensions must be distinct, both are: " + << inner.name() << "\n"; + + // The dimensions need to be adjacent before fusing. Verify the reordering is safe. + if (outer_pos != inner_pos + 1) { + vector<VarOrRVar> order; + order.reserve(dims.size() - 1); + for (int i = 0; i < (int)dims.size() - 1; i++) { + if (i == outer_pos) { + continue; + } + order.emplace_back(split_string(dims[i].var, ".").back(), dims[i].is_rvar()); + if (i == inner_pos) { + order.emplace_back(split_string(dims[outer_pos].var, ".").back(), dims[outer_pos].is_rvar()); + } + } + reorder(order); + for (int i = 0; i < (int)dims.size(); i++) { + if (dim_match(dims[i], inner)) { + inner_pos = i; + break; } - // We just changed the dim_type without checking the - // for_type. Redundantly re-set the for type on the fused var just - // to trigger validation of the existing for_type. - set_dim_type(fused, dims[i].for_type); } + outer_pos = inner_pos + 1; } - if (!found_inner) { - user_error << "In schedule for " << name() - << ", could not find inner fuse dimension: " - << inner.name() - << "\n" - << dump_argument_list(); + string inner_name = dims[inner_pos].var; + string outer_name = dims[outer_pos].var; + string fused_name = inner_name + "." 
+ fused.name(); + + DimType outer_type = dims[outer_pos].dim_type; + dims.erase(dims.begin() + outer_pos); + + dims[inner_pos].var = fused_name; + if (dims[inner_pos].dim_type == DimType::ImpureRVar || outer_type == DimType::ImpureRVar) { + dims[inner_pos].dim_type = DimType::ImpureRVar; + } else if (dims[inner_pos].dim_type == DimType::PureRVar || outer_type == DimType::PureRVar) { + dims[inner_pos].dim_type = DimType::PureRVar; + } else { + dims[inner_pos].dim_type = DimType::PureVar; } + // We just changed the dim_type without checking the for_type. Redundantly + // re-set the for type on the fused var just to trigger validation of the + // existing for_type. + set_dim_type(fused, dims[inner_pos].for_type); // Add the fuse to the splits list Split split = {fused_name, outer_name, inner_name, Expr(), true, TailStrategy::RoundUp, Split::FuseVars}; diff --git a/test/correctness/fuse.cpp b/test/correctness/fuse.cpp index d644e6fb741e..3df0c57c6b08 100644 --- a/test/correctness/fuse.cpp +++ b/test/correctness/fuse.cpp @@ -44,6 +44,28 @@ int main(int argc, char **argv) { } } + { + Func f, g; + RDom r(0, 10); + RVar ro, ri, fused; + f(x, y) = 0; + f(x, y) += (x + y + r); + // swap the order when fusing + f.update().split(r, ro, ri, 8).fuse(ro, ri, fused); + + g(x, y) = 0; + g(x, y) += (x + y + r); + + RDom r2(-16, 32, -16, 32); + Func error; + error() = maximum(abs(f(r2.x, r2.y) - g(r2.x, r2.y))); + int err = evaluate_may_gpu(error()); + if (err != 0) { + printf("Fusion caused a difference in the output\n"); + return 1; + } + } + class CheckForMod : public Internal::IRMutator { using IRMutator::visit; diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 79c5bff5063e..0d18a7a87a3a 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -68,6 +68,7 @@ tests( func_tuple_dim_mismatch.cpp func_tuple_types_mismatch.cpp func_tuple_update_types_mismatch.cpp + fuse_same_var.cpp fuse_vectorized_var_with_rvar.cpp hoist_storage_extern.cpp 
hoist_storage_without_compute_at.cpp @@ -110,6 +111,7 @@ tests( rfactor_inner_dim_non_commutative.cpp round_up_and_blend_race.cpp run_with_large_stack_throws.cpp + rvar_fuse_reorder.cpp shift_inwards_and_blend_race.cpp specialize_fail.cpp split_inner_wrong_tail_strategy.cpp diff --git a/test/error/fuse_same_var.cpp b/test/error/fuse_same_var.cpp new file mode 100644 index 000000000000..a943b1cafcc7 --- /dev/null +++ b/test/error/fuse_same_var.cpp @@ -0,0 +1,13 @@ +#include "Halide.h" +using namespace Halide; + +int main(int argc, char **argv) { + Func f; + Var x, y, fused; + + f(x, y) = x + y; + f.fuse(x, x, fused); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/rvar_fuse_reorder.cpp b/test/error/rvar_fuse_reorder.cpp new file mode 100644 index 000000000000..dbdfaea34bf4 --- /dev/null +++ b/test/error/rvar_fuse_reorder.cpp @@ -0,0 +1,19 @@ +#include "Halide.h" +#include <stdio.h> + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"); + Var x("x"); + RDom r(0, 4, 0, 3, "r"); + RVar rxo("rxo"), rxi("rxi"), fused("fused"); + f(x) = 1; + f(x) = f(x) * 2 + r.x * 5 + r.y; + + // this fuse reorders rvars despite the update being non-commutative + f.update().split(r.x, rxo, rxi, 2).fuse(rxi, r.y, fused); + + printf("Success!\n"); + return 0; +}