-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathBoundConstantExtentLoops.cpp
More file actions
136 lines (119 loc) · 4.83 KB
/
BoundConstantExtentLoops.cpp
File metadata and controls
136 lines (119 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include "BoundConstantExtentLoops.h"
#include "Bounds.h"
#include "CSE.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Simplify.h"
#include "SimplifyCorrelatedDifferences.h"
#include "Substitute.h"
namespace Halide {
namespace Internal {
namespace {
class BoundLoops : public IRMutator {
protected:
using IRMutator::visit;
std::vector<std::pair<std::string, Expr>> lets;
Stmt visit(const LetStmt *op) override {
if (is_pure(op->value)) {
lets.emplace_back(op->name, op->value);
Stmt s = IRMutator::visit(op);
lets.pop_back();
return s;
} else {
return IRMutator::visit(op);
}
}
std::vector<Expr> facts;
Stmt visit(const IfThenElse *op) override {
facts.push_back(op->condition);
Stmt then_case = mutate(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
facts.back() = simplify(!op->condition);
else_case = mutate(op->else_case);
}
facts.pop_back();
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return op;
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
}
Stmt visit(const For *op) override {
Expr extent = simplify(op->extent());
if (is_const(extent)) {
// Nothing needs to be done
return IRMutator::visit(op);
}
if (op->for_type == ForType::Unrolled ||
op->for_type == ForType::Vectorized) {
// Give it one last chance to simplify to an int
Stmt body = op->body;
const IntImm *e = extent.as<IntImm>();
if (e == nullptr) {
// We're about to hard fail. Get really aggressive
// with the simplifier.
for (const auto &[var, value] : reverse_view(lets)) {
extent = Let::make(var, value, extent);
}
extent = remove_likelies(extent);
extent = substitute_in_all_lets(extent);
extent = simplify(extent,
Scope<Interval>::empty_scope(),
Scope<ModulusRemainder>::empty_scope(),
facts);
e = extent.as<IntImm>();
}
Expr extent_upper;
if (e == nullptr) {
// Still no luck. Try taking an upper bound and
// injecting an if statement around the body.
extent_upper = find_constant_bound(extent, Direction::Upper, Scope<Interval>());
if (extent_upper.defined()) {
e = extent_upper.as<IntImm>();
body =
IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) <=
op->max),
body);
}
}
if (e == nullptr && permit_failed_unroll && op->for_type == ForType::Unrolled) {
// Still no luck, but we're allowed to fail. Rewrite
// to a serial loop.
user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. Did you mean to do this?\n";
body = mutate(body);
return For::make(op->name, op->min, op->max,
ForType::Serial, op->partition_policy, op->device_api, std::move(body));
}
user_assert(e)
<< "Can only " << (op->for_type == ForType::Unrolled ? "unroll" : "vectorize")
<< " for loops over a constant extent.\n"
<< "Loop over " << op->name << " has extent " << extent << ".\n";
body = mutate(body);
return For::make(op->name, op->min, (op->min + e) - 1,
op->for_type, op->partition_policy, op->device_api, std::move(body));
} else {
return IRMutator::visit(op);
}
}
bool permit_failed_unroll = false;
public:
BoundLoops() {
// Experimental autoschedulers may want to unroll without
// being totally confident the loop will indeed turn out
// to be constant-sized. If this feature continues to be
// important, we need to expose it in the scheduling
// language somewhere, but how? For now we do something
// ugly and expedient.
// For the tracking issue to fix this, see
// https://github.com/halide/Halide/issues/3479
permit_failed_unroll = get_env_variable("HL_PERMIT_FAILED_UNROLL") == "1";
}
};
} // namespace
/** Give every Unrolled or Vectorized loop in `s` a constant extent (see
 * BoundLoops), and return the rewritten statement. */
Stmt bound_constant_extent_loops(const Stmt &s) {
    // A fresh mutator per call. It keeps traversal-local stacks of enclosing
    // lets and branch conditions.
    BoundLoops bounder;
    return bounder(s);
}
} // namespace Internal
} // namespace Halide