@@ -16,6 +16,8 @@ namespace Internal {
1616using std::string;
1717using std::vector;
1818
19+ namespace {
20+
1921class BaseCaseSolver : public IRVisitor {
2022 using IRVisitor::visit;
2123 const vector<string> &vars;
@@ -31,16 +33,24 @@ class BaseCaseSolver : public IRVisitor {
3133
3234 void visit (const Call *op) override {
3335 if (op->is_intrinsic (Call::if_then_else)) {
36+ if (nested_select == 0 ) {
37+ visit_with (op->args [0 ], [&](auto *self, const Call *inner_op) {
38+ user_assert (inner_op->name != func) << " Function " << func << " contains an inductive function reference outside of a select operation value.\n " ;
39+ self->visit_base (inner_op);
40+ });
41+ }
42+
3443 nested_select += 1 ;
3544 vector<Interval> old_intervals = condition_intervals;
3645 for (size_t i = 0 ; i < vars.size (); i++) {
37- condition_intervals[i] = Interval::make_intersection (old_intervals[i], solve_for_outer_interval (simplify (op->args [0 ]), vars[i]));
46+ Interval inter = Interval::make_intersection (old_intervals[i], solve_for_outer_interval (simplify (op->args [0 ]), vars[i]));
47+ condition_intervals[i] = Interval (inter.min , Interval::pos_inf ());
3848 bounds.push (vars[i], condition_intervals[i]);
3949 }
40-
4150 op->args [1 ].accept (this );
4251 for (size_t i = 0 ; i < vars.size (); i++) {
43- condition_intervals[i] = Interval::make_intersection (old_intervals[i], solve_for_outer_interval (simplify (!op->args [0 ]), vars[i]));
52+ Interval inter = Interval::make_intersection (old_intervals[i], solve_for_outer_interval (simplify (!op->args [0 ]), vars[i]));
53+ condition_intervals[i] = Interval (inter.min , Interval::pos_inf ());
4454 bounds.pop (vars[i]);
4555 bounds.push (vars[i], condition_intervals[i]);
4656 }
@@ -51,7 +61,7 @@ class BaseCaseSolver : public IRVisitor {
5161 }
5262 nested_select -= 1 ;
5363 } else if (op->name == func) {
54- user_assert (nested_select > 0 ) << " Function " << func << " contains an inductive function reference outside of a select operation.\n " ;
64+ user_assert (nested_select > 0 ) << " Function " << func << " contains an inductive function reference outside of a select operation value .\n " ;
5565 user_assert (nested_select == 1 ) << " Function " << func << " contains an inductive function reference inside a nested select operation.\n " ;
5666 bool found_inductive = false ;
5767 for (size_t position = 0 ; position < vars.size (); position++) {
@@ -67,7 +77,9 @@ class BaseCaseSolver : public IRVisitor {
6777 found_inductive = true ;
6878 new_interval = Interval (Interval::neg_inf (), start_box[position].max );
6979 } else {
70- new_interval = Interval::everything ();
80+ std::ostringstream err;
81+ err << " Inductive variable " << vars[position] << " in inductive function " << func << " is not provably monotonically decreasing outside of the base case." ;
82+ user_error << err.str () << " \n " ;
7183 }
7284 new_interval = Interval::make_intersection (new_interval, condition_intervals[position]);
7385 Scope<Interval> i_scope;
@@ -93,7 +105,7 @@ class BaseCaseSolver : public IRVisitor {
93105 }
94106};
95107
96- // anonymous namespace
108+ } // anonymous namespace
97109
98110Box expand_to_include_base_case (const vector<string> &vars, const Expr &RHS , const string &func, const Box &box_required) {
99111 Expr substed = substitute_in_all_lets (RHS );
@@ -109,18 +121,5 @@ Box expand_to_include_base_case(const vector<string> &vars, const Expr &RHS, con
109121 return box2;
110122}
111123
112- Box expand_to_include_base_case (const Function &fn, const Box &box_required, const int &pos) {
113- return expand_to_include_base_case (fn.args (), fn.values ()[pos], fn.name (), box_required);
114- }
115-
116- Box expand_to_include_base_case (const Function &fn, const Box &box_required) {
117- Box b = expand_to_include_base_case (fn.args (), fn.values ()[0 ], fn.name (), box_required);
118- for (size_t pos = 1 ; pos < fn.values ().size (); pos++) {
119- Box b2 = expand_to_include_base_case (fn.args (), fn.values ()[pos], fn.name (), box_required);
120- merge_boxes (b, b2);
121- }
122- return b;
123- }
124-
125124} // namespace Internal
126125} // namespace Halide
0 commit comments