diff --git a/src/RegionCosts.cpp b/src/RegionCosts.cpp index 866767a11dba..8a580e1e931b 100644 --- a/src/RegionCosts.cpp +++ b/src/RegionCosts.cpp @@ -244,10 +244,6 @@ class ExprCost : public IRVisitor { } } - void visit(const Shuffle *op) override { - arith += 1; - } - void visit(const Let *let) override { let->value.accept(this); let->body.accept(this); @@ -255,50 +251,63 @@ class ExprCost : public IRVisitor { // None of the following IR nodes should be encountered when traversing the // IR at the level at which the auto scheduler operates. - void visit(const Load *) override { - internal_error; + void fail(const Expr &e) { + internal_error << "Unexpected Expr while computing region costs: " << e << "\n" + << "Expected front-end Exprs only."; + } + void fail(const Stmt &s) { + internal_error << "Unexpected Stmt while computing region costs:\n" + << s << "\n" + << "Expected front-end Exprs only."; } - void visit(const Ramp *) override { - internal_error; + + void visit(const Load *op) override { + fail(op); + } + void visit(const Ramp *op) override { + fail(op); + } + void visit(const Shuffle *op) override { + fail(op); } - void visit(const Broadcast *) override { - internal_error; + void visit(const Broadcast *op) override { + fail(op); } - void visit(const LetStmt *) override { - internal_error; + void visit(const LetStmt *op) override { + fail(op); } - void visit(const AssertStmt *) override { - internal_error; + void visit(const AssertStmt *op) override { + fail(op); } - void visit(const ProducerConsumer *) override { - internal_error; + void visit(const ProducerConsumer *op) override { + fail(op); } - void visit(const For *) override { - internal_error; + void visit(const For *op) override { + fail(op); } - void visit(const Store *) override { - internal_error; + void visit(const Store *op) override { + fail(op); } - void visit(const Provide *) override { - internal_error; + void visit(const Provide *op) override { + fail(op); } - void visit(const Allocate *) 
override { - internal_error; + void visit(const Allocate *op) override { + fail(op); } - void visit(const Free *) override { - internal_error; + void visit(const Free *op) override { + fail(op); } - void visit(const Realize *) override { - internal_error; + void visit(const Realize *op) override { + fail(op); } - void visit(const Block *) override { - internal_error; + void visit(const Block *op) override { + fail(op); } - void visit(const IfThenElse *) override { - internal_error; + void visit(const IfThenElse *op) override { + fail(op); } - void visit(const Evaluate *) override { - internal_error; + void visit(const Evaluate *op) override { + fail(op); } public: diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index 3d12513ff6cc..314780d981bd 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -170,6 +170,7 @@ class SplitTuples : public IRMutator { could_alias(op->args, store_args)) { deps.insert(op->value_index); } + IRVisitor::visit(op); } bool could_alias(const vector<Expr> &a, const vector<Expr> &b) { diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 961a541707a8..18243503372b 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -1452,8 +1452,7 @@ class FindVectorizableExprsInAtomicNode : public IRMutator { Stmt visit(const Store *op) override { // A store poisons all subsequent loads, but loads before the // first store can be lifted. 
- mutate(op->index); - mutate(op->value); + IRMutator::visit(op); poisoned_names.push(op->name); return op; } diff --git a/test/correctness/custom_lowering_pass.cpp b/test/correctness/custom_lowering_pass.cpp index a90d0960a032..2a55a5bc95c3 100644 --- a/test/correctness/custom_lowering_pass.cpp +++ b/test/correctness/custom_lowering_pass.cpp @@ -18,6 +18,7 @@ class CheckForFloatDivision : public IRMutator { std::cerr << "Found floating-point division by constant: " << Expr(op) << "\n"; exit(1); } + IRMutator::visit(op); return op; } }; diff --git a/test/correctness/tuple_reduction.cpp b/test/correctness/tuple_reduction.cpp index 862444f6b655..6e57ab9426ae 100644 --- a/test/correctness/tuple_reduction.cpp +++ b/test/correctness/tuple_reduction.cpp @@ -171,6 +171,46 @@ int main(int argc, char **argv) { } } + { + // A case which requires tuple updates to be atomic, but hides a + // dependence in a way that triggered a bug in the past. + Func f, g; + Var x, y; + + f(x) = Tuple(x + 17, x + 1); + constexpr int w = 100; + + RDom r(0, w); + f(r) = Tuple(f(r)[0] + 5, f(clamp(f(r)[0], 0, w - 1))[1]); + g(x, y) = mux(y, {f(x)[0], f(x)[1]}); + + f.compute_root(); + + Buffer<int> buf = g.realize({w, 2}); + Buffer<int> correct(w, 2); + for (int x = 0; x < w; x++) { + correct(x, 0) = x + 17; + correct(x, 1) = x + 1; + } + for (int r = 0; r < w; r++) { + int new_0 = correct(r, 0) + 5; + int new_1 = correct(std::min(std::max(correct(r, 0), 0), w - 1), 1); + // Tuple element 1 might depend on the old value of tuple element + // zero. The new values must be both computed *then* assigned. + correct(r, 0) = new_0; + correct(r, 1) = new_1; + } + + for (int x = 0; x < w; x++) { + for (int y = 0; y < 2; y++) { + if (buf(x, y) != correct(x, y)) { + printf("buf(%d, %d) = %d instead of %d\n", x, y, buf(x, y), correct(x, y)); + return -1; + } + } + } + } + printf("Success!\n"); return 0; }