Skip to content

Commit 512468f

Browse files
committed
additional bugfixes and tests
1 parent d558c8f commit 512468f

8 files changed

Lines changed: 102 additions & 76 deletions

src/Function.cpp

Lines changed: 26 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -659,10 +659,6 @@ void Function::define(const vector<string> &args, vector<Expr> values) {
659659
for (auto &value : values) {
660660
value = weakener(value);
661661
}
662-
if (check.reduction_domain.defined()) {
663-
check.reduction_domain.set_predicate(
664-
weakener(check.reduction_domain.predicate()));
665-
}
666662

667663
ReductionDomain rdom;
668664
contents->init_def = Definition(init_def_args, values, rdom, true);
@@ -1109,63 +1105,24 @@ bool Function::has_pure_definition() const {
11091105
}
11101106

11111107
bool Function::is_inductive() const {
1112-
class RecursiveHelper : public IRVisitor {
1113-
using IRVisitor::visit;
1114-
const string &func;
1115-
void visit(const Call *op) override {
1116-
if (op->name == func) {
1117-
recursive = true;
1118-
}
1119-
IRVisitor::visit(op);
1120-
}
1121-
1122-
public:
1123-
bool recursive = false;
1124-
RecursiveHelper(const string &func)
1125-
: func(func) {
1126-
}
1127-
};
1128-
11291108
if (!has_pure_definition()) {
11301109
return false;
11311110
}
11321111

1133-
RecursiveHelper r(name());
1112+
bool recursive = false;
11341113
for (const Expr &e : definition().values()) {
1135-
e.accept(&r);
1114+
visit_with(e, [&](auto *self, const Call *op) {
1115+
if (op->name == name()) {
1116+
recursive = true;
1117+
}
1118+
self->visit_base(op);
1119+
});
11361120
}
11371121

1138-
return r.recursive;
1122+
return recursive;
11391123
}
11401124

11411125
bool Function::is_inductive(const string &var) const {
1142-
class RecursiveHelper : public IRVisitor {
1143-
using IRVisitor::visit;
1144-
const string &func;
1145-
const string &var;
1146-
const int &pos;
1147-
void visit(const Call *op) override {
1148-
if (op->name == func) {
1149-
recursive = true;
1150-
if (const auto &v = op->args[pos].as<Variable>()) {
1151-
if (v->name != var) {
1152-
inductive_in_var = true;
1153-
}
1154-
} else {
1155-
inductive_in_var = true;
1156-
}
1157-
}
1158-
IRVisitor::visit(op);
1159-
}
1160-
1161-
public:
1162-
bool recursive = false;
1163-
bool inductive_in_var = false;
1164-
RecursiveHelper(const string &func, const string &var, const int &pos)
1165-
: func(func), var(var), pos(pos) {
1166-
}
1167-
};
1168-
11691126
if (!has_pure_definition()) {
11701127
return false;
11711128
}
@@ -1178,15 +1135,30 @@ bool Function::is_inductive(const string &var) const {
11781135
}
11791136
}
11801137
}
1138+
11811139
if (pos == -1) {
11821140
return false;
11831141
}
1184-
RecursiveHelper r(name(), var, pos);
1142+
1143+
bool recursive = false;
1144+
bool inductive_in_var = false;
11851145
for (const Expr &e : definition().values()) {
1186-
e.accept(&r);
1146+
visit_with(e, [&](auto *self, const Call *op) {
1147+
if (op->name == name()) {
1148+
recursive = true;
1149+
if (const auto &v = op->args[pos].as<Variable>()) {
1150+
if (v->name != var) {
1151+
inductive_in_var = true;
1152+
}
1153+
} else {
1154+
inductive_in_var = true;
1155+
}
1156+
}
1157+
self->visit_base(op);
1158+
});
11871159
}
11881160

1189-
return r.inductive_in_var;
1161+
return inductive_in_var;
11901162
}
11911163

11921164
bool Function::can_be_inlined() const {

src/Inductive.cpp

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ namespace Internal {
1616
using std::string;
1717
using std::vector;
1818

19+
namespace {
20+
1921
class BaseCaseSolver : public IRVisitor {
2022
using IRVisitor::visit;
2123
const vector<string> &vars;
@@ -31,16 +33,24 @@ class BaseCaseSolver : public IRVisitor {
3133

3234
void visit(const Call *op) override {
3335
if (op->is_intrinsic(Call::if_then_else)) {
36+
if (nested_select == 0) {
37+
visit_with(op->args[0], [&](auto *self, const Call *inner_op) {
38+
user_assert(inner_op->name != func) << "Function " << func << " contains an inductive function reference outside of a select operation value.\n";
39+
self->visit_base(inner_op);
40+
});
41+
}
42+
3443
nested_select += 1;
3544
vector<Interval> old_intervals = condition_intervals;
3645
for (size_t i = 0; i < vars.size(); i++) {
37-
condition_intervals[i] = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(op->args[0]), vars[i]));
46+
Interval inter = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(op->args[0]), vars[i]));
47+
condition_intervals[i] = Interval(inter.min, Interval::pos_inf());
3848
bounds.push(vars[i], condition_intervals[i]);
3949
}
40-
4150
op->args[1].accept(this);
4251
for (size_t i = 0; i < vars.size(); i++) {
43-
condition_intervals[i] = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(!op->args[0]), vars[i]));
52+
Interval inter = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(!op->args[0]), vars[i]));
53+
condition_intervals[i] = Interval(inter.min, Interval::pos_inf());
4454
bounds.pop(vars[i]);
4555
bounds.push(vars[i], condition_intervals[i]);
4656
}
@@ -51,7 +61,7 @@ class BaseCaseSolver : public IRVisitor {
5161
}
5262
nested_select -= 1;
5363
} else if (op->name == func) {
54-
user_assert(nested_select > 0) << "Function " << func << " contains an inductive function reference outside of a select operation.\n";
64+
user_assert(nested_select > 0) << "Function " << func << " contains an inductive function reference outside of a select operation value.\n";
5565
user_assert(nested_select == 1) << "Function " << func << " contains an inductive function reference inside a nested select operation.\n";
5666
bool found_inductive = false;
5767
for (size_t position = 0; position < vars.size(); position++) {
@@ -67,7 +77,9 @@ class BaseCaseSolver : public IRVisitor {
6777
found_inductive = true;
6878
new_interval = Interval(Interval::neg_inf(), start_box[position].max);
6979
} else {
70-
new_interval = Interval::everything();
80+
std::ostringstream err;
81+
err << "Inductive variable " << vars[position] << " in inductive function " << func << " is not provably monotonically decreasing outside of the base case.";
82+
user_error << err.str() << "\n";
7183
}
7284
new_interval = Interval::make_intersection(new_interval, condition_intervals[position]);
7385
Scope<Interval> i_scope;
@@ -93,7 +105,7 @@ class BaseCaseSolver : public IRVisitor {
93105
}
94106
};
95107

96-
// anonymous namespace
108+
} // anonymous namespace
97109

98110
Box expand_to_include_base_case(const vector<string> &vars, const Expr &RHS, const string &func, const Box &box_required) {
99111
Expr substed = substitute_in_all_lets(RHS);
@@ -109,18 +121,5 @@ Box expand_to_include_base_case(const vector<string> &vars, const Expr &RHS, con
109121
return box2;
110122
}
111123

112-
Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos) {
113-
return expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required);
114-
}
115-
116-
Box expand_to_include_base_case(const Function &fn, const Box &box_required) {
117-
Box b = expand_to_include_base_case(fn.args(), fn.values()[0], fn.name(), box_required);
118-
for (size_t pos = 1; pos < fn.values().size(); pos++) {
119-
Box b2 = expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required);
120-
merge_boxes(b, b2);
121-
}
122-
return b;
123-
}
124-
125124
} // namespace Internal
126125
} // namespace Halide

src/Inductive.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ namespace Internal {
5050
/** Given an initial box for an inductively defined function,
5151
returns an expanded box that includes the function's non-inductive base case. */
5252
Box expand_to_include_base_case(const std::vector<std::string> &vars, const Expr &RHS, const std::string &func, const Box &box_required);
53-
Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos = 0);
54-
Box expand_to_include_base_case(const Function &fn, const Box &box_required);
5553

5654
} // namespace Internal
5755
} // namespace Halide

test/error/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ tests(
7474
implicit_args.cpp
7575
impossible_constraints.cpp
7676
incomplete_target.cpp
77-
inductive_loop.cpp
77+
inductive_2d_arbitrary.cpp
78+
inductive_cond_nested.cpp
79+
inductive_cond_self_reference.cpp
80+
inductive_loop_1.cpp
7881
inductive_loop_2.cpp
7982
inductive_loop_3.cpp
8083
inductive_nested_select.cpp
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include "Halide.h"
2+
#include <stdio.h>
3+
4+
using namespace Halide;
5+
6+
int main(int argc, char **argv) {
7+
Func f("f"), g("g");
8+
9+
Var x("x"), y("y");
10+
11+
f(x, y) = select(x < 2 || y < 2 || y > 6, 0, f(x - 1, y + 1)) + select(x < 2 || x > 6 || y < 2, 0, f(x + 1, y - 1));
12+
g(x, y) = f(x, y) * 2;
13+
14+
g.realize({10, 10});
15+
16+
printf("Success!\n");
17+
return 0;
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include "Halide.h"
2+
#include <stdio.h>
3+
4+
using namespace Halide;
5+
6+
int main(int argc, char **argv) {
7+
Func f("f"), g("g");
8+
9+
Var x("x");
10+
11+
f(x) = select(x < max(2, abs(select(x < 0, f(x - 1), 5))), 0, f(x - 1) + x);
12+
g(x) = f(x) * 2;
13+
14+
g.realize({10});
15+
16+
printf("Success!\n");
17+
return 0;
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include "Halide.h"
2+
#include <stdio.h>
3+
4+
using namespace Halide;
5+
6+
int main(int argc, char **argv) {
7+
Func f("f"), g("g");
8+
9+
Var x("x");
10+
11+
f(x) = select(x < f(x - 1), 0, f(x - 1) + x);
12+
g(x) = f(x) * 2;
13+
14+
g.realize({10});
15+
16+
printf("Success!\n");
17+
return 0;
18+
}

0 commit comments

Comments
 (0)