Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 55 additions & 43 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1317,58 +1317,70 @@ Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRV
debug(4) << "In schedule for " << name() << ", fuse " << outer.name()
<< " and " << inner.name() << " into " << fused.name() << "\n";

// Replace the old dimensions with the new dimension in the dims list
bool found_outer = false, found_inner = false;
string inner_name, outer_name, fused_name;
vector<Dim> &dims = definition.schedule().dims();

DimType outer_type = DimType::PureRVar;
for (size_t i = 0; (!found_outer) && i < dims.size(); i++) {
int inner_pos = -1, outer_pos = -1;
for (int i = 0; i < (int)dims.size(); i++) {
if (dim_match(dims[i], inner)) {
inner_pos = i;
}
if (dim_match(dims[i], outer)) {
found_outer = true;
outer_name = dims[i].var;
outer_type = dims[i].dim_type;
dims.erase(dims.begin() + i);
outer_pos = i;
}
}
if (!found_outer) {
user_error << "In schedule for " << name()
<< ", could not find outer fuse dimension: "
<< outer.name()
<< "\n"
<< dump_argument_list();
}

for (size_t i = 0; (!found_inner) && i < dims.size(); i++) {
if (dim_match(dims[i], inner)) {
found_inner = true;
inner_name = dims[i].var;
fused_name = inner_name + "." + fused.name();
dims[i].var = fused_name;

if (dims[i].dim_type == DimType::ImpureRVar ||
outer_type == DimType::ImpureRVar) {
dims[i].dim_type = DimType::ImpureRVar;
} else if (dims[i].dim_type == DimType::PureRVar ||
outer_type == DimType::PureRVar) {
dims[i].dim_type = DimType::PureRVar;
} else {
dims[i].dim_type = DimType::PureVar;
user_assert(inner_pos >= 0) << "In schedule for " << name()
<< ", could not find inner fuse dimension: "
<< inner.name() << "\n"
<< dump_argument_list();
user_assert(outer_pos >= 0) << "In schedule for " << name()
<< ", could not find outer fuse dimension: "
<< outer.name() << "\n"
<< dump_argument_list();
user_assert(inner_pos != outer_pos) << "In schedule for " << name()
<< ", inner and outer fuse dimensions must be distinct, both are: "
<< inner.name() << "\n";

// The dimensions need to be adjacent before fusing. Verify the reordering is safe.
if (outer_pos != inner_pos + 1) {
vector<VarOrRVar> order;
order.reserve(dims.size() - 1);
for (int i = 0; i < (int)dims.size() - 1; i++) {
if (i == outer_pos) {
continue;
}
order.emplace_back(split_string(dims[i].var, ".").back(), dims[i].is_rvar());
if (i == inner_pos) {
order.emplace_back(split_string(dims[outer_pos].var, ".").back(), dims[outer_pos].is_rvar());
}
}
reorder(order);
for (int i = 0; i < (int)dims.size(); i++) {
if (dim_match(dims[i], inner)) {
inner_pos = i;
break;
}
// We just changed the dim_type without checking the
// for_type. Redundantly re-set the for type on the fused var just
// to trigger validation of the existing for_type.
set_dim_type(fused, dims[i].for_type);
}
outer_pos = inner_pos + 1;
}

if (!found_inner) {
user_error << "In schedule for " << name()
<< ", could not find inner fuse dimension: "
<< inner.name()
<< "\n"
<< dump_argument_list();
string inner_name = dims[inner_pos].var;
string outer_name = dims[outer_pos].var;
string fused_name = inner_name + "." + fused.name();

DimType outer_type = dims[outer_pos].dim_type;
dims.erase(dims.begin() + outer_pos);

dims[inner_pos].var = fused_name;
if (dims[inner_pos].dim_type == DimType::ImpureRVar || outer_type == DimType::ImpureRVar) {
dims[inner_pos].dim_type = DimType::ImpureRVar;
} else if (dims[inner_pos].dim_type == DimType::PureRVar || outer_type == DimType::PureRVar) {
dims[inner_pos].dim_type = DimType::PureRVar;
} else {
dims[inner_pos].dim_type = DimType::PureVar;
}
// We just changed the dim_type without checking the for_type. Redundantly
// re-set the for type on the fused var just to trigger validation of the
// existing for_type.
set_dim_type(fused, dims[inner_pos].for_type);

// Add the fuse to the splits list
Split split = {fused_name, outer_name, inner_name, Expr(), true, TailStrategy::RoundUp, Split::FuseVars};
Expand Down
22 changes: 22 additions & 0 deletions test/correctness/fuse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,28 @@ int main(int argc, char **argv) {
}
}

{
Func f, g;
RDom r(0, 10);
RVar ro, ri, fused;
f(x, y) = 0;
f(x, y) += (x + y + r);
// swap the order when fusing
f.update().split(r, ro, ri, 8).fuse(ro, ri, fused);

g(x, y) = 0;
g(x, y) += (x + y + r);

RDom r2(-16, 32, -16, 32);
Func error;
error() = maximum(abs(f(r2.x, r2.y) - g(r2.x, r2.y)));
int err = evaluate_may_gpu<uint32_t>(error());
if (err != 0) {
printf("Fusion caused a difference in the output\n");
return 1;
}
}

class CheckForMod : public Internal::IRMutator {
using IRMutator::visit;

Expand Down
2 changes: 2 additions & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ tests(
func_tuple_dim_mismatch.cpp
func_tuple_types_mismatch.cpp
func_tuple_update_types_mismatch.cpp
fuse_same_var.cpp
fuse_vectorized_var_with_rvar.cpp
hoist_storage_extern.cpp
hoist_storage_without_compute_at.cpp
Expand Down Expand Up @@ -110,6 +111,7 @@ tests(
rfactor_inner_dim_non_commutative.cpp
round_up_and_blend_race.cpp
run_with_large_stack_throws.cpp
rvar_fuse_reorder.cpp
shift_inwards_and_blend_race.cpp
specialize_fail.cpp
split_inner_wrong_tail_strategy.cpp
Expand Down
13 changes: 13 additions & 0 deletions test/error/fuse_same_var.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "Halide.h"
using namespace Halide;

int main(int argc, char **argv) {
Func f;
Var x, y, fused;

f(x, y) = x + y;
f.fuse(x, x, fused);

printf("Success!\n");
return 0;
}
19 changes: 19 additions & 0 deletions test/error/rvar_fuse_reorder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
Func f("f");
Var x("x");
RDom r(0, 4, 0, 3, "r");
RVar rxo("rxo"), rxi("rxi"), fused("fused");
f(x) = 1;
f(x) = f(x) * 2 + r.x * 5 + r.y;

// this fuse reorders rvars despite the update being non-commutative
f.update().split(r.x, rxo, rxi, 2).fuse(rxi, r.y, fused);

printf("Success!\n");
return 0;
}
Loading