-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathAddParameterChecks.cpp
More file actions
138 lines (109 loc) · 4.02 KB
/
AddParameterChecks.cpp
File metadata and controls
138 lines (109 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include "AddParameterChecks.h"
#include "CompilerProfiling.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "Substitute.h"
#include "Target.h"
namespace Halide {
namespace Internal {
using std::map;
using std::pair;
using std::string;
using std::vector;
namespace {

// IR visitor that records every Parameter handle attached to a Variable
// node in the visited statement. No filtering is done here (e.g. buffer
// vs. scalar); callers inspect the collected handles themselves. If two
// Variables share a name, the later visit overwrites the earlier entry.
struct FindParameters : public IRGraphVisitor {
    // Collected Parameter handles, keyed by the referencing Variable's name.
    map<string, Parameter> params;

    using IRGraphVisitor::visit;

    void visit(const Variable *op) override {
        // Plain variables carry no Parameter; nothing to record.
        if (!op->param.defined()) {
            return;
        }
        params[op->name] = op->param;
    }
};

}  // namespace
// Insert checks to make sure that parameters are within their
// declared range.
//
// For every Parameter found in `s` that is not a buffer and declares a
// min_value() and/or max_value():
//   * uses of the parameter inside `s` are replaced by a variable named
//     "<name>.constrained", bound via LetStmt to the parameter clamped into
//     its declared range (max() with min_value, then min() with max_value);
//   * one AssertStmt per declared bound is emitted; on failure it calls
//     halide_error_param_too_{small,large}_{i64,u64,f64} with the parameter
//     name, its (64-bit widened) value, and the violated limit.
// The caller-supplied `preconditions` are placed ahead of everything else,
// so they are checked first.
//
// `t` is not consulted by this pass.
Stmt add_parameter_checks(const vector<Stmt> &preconditions, Stmt s, const Target &t) {
    (void)t;  // Currently unused; kept for interface compatibility.

    // First, find all the parameters
    FindParameters finder;
    s.accept(&finder);

    map<string, Expr> replace_with_constrained;
    vector<pair<string, Expr>> lets;

    struct ParamAssert {
        Expr condition;
        Expr value, limit_value;
        string param_name;
    };
    vector<ParamAssert> asserts;

    // Make constrained versions of the params
    for (const auto &entry : finder.params) {
        const string &var_name = entry.first;
        const Parameter &param = entry.second;

        if (param.is_buffer() ||
            !(param.min_value().defined() || param.max_value().defined())) {
            continue;
        }

        string constrained_name = var_name + ".constrained";

        Expr constrained_var = Variable::make(param.type(), constrained_name);
        Expr constrained_value = Variable::make(param.type(), var_name, param);
        replace_with_constrained[var_name] = constrained_var;

        if (param.min_value().defined()) {
            ParamAssert p = {
                constrained_value >= param.min_value(),
                constrained_value, param.min_value(),
                param.name()};
            asserts.push_back(p);
            constrained_value = max(constrained_value, param.min_value());
        }

        if (param.max_value().defined()) {
            // Note: constrained_value may already be clamped from below here.
            ParamAssert p = {
                constrained_value <= param.max_value(),
                constrained_value, param.max_value(),
                param.name()};
            asserts.push_back(p);
            constrained_value = min(constrained_value, param.max_value());
        }

        lets.emplace_back(constrained_name, constrained_value);
    }

    // Replace the params with their constrained version in the rest of the pipeline
    s = substitute(replace_with_constrained, s);

    // Inject the let statements
    for (const auto &let : lets) {
        s = LetStmt::make(let.first, let.second, s);
    }

    // Inject the assert statements. Each one is prepended to `s`, so the
    // assert pushed last ends up executing first.
    for (ParamAssert &p : asserts) {
        // Upgrade the types to 64-bit versions for the error call.
        // Type::is_float() is also true for bfloat, and with_bits(64) on a
        // bfloat type would yield BFloat(64), which has no meaning. The
        // *_f64 error entry points take doubles, so every floating-point
        // kind is widened to a true Float(64).
        Type wider = p.value.type().with_bits(64);
        if (wider.is_float()) {
            wider = Float(64, wider.lanes());
        }
        p.limit_value = cast(wider, p.limit_value);
        p.value = cast(wider, p.value);

        string error_call_name = "halide_error_param";

        if (p.condition.as<LE>()) {
            error_call_name += "_too_large";
        } else {
            internal_assert(p.condition.as<GE>());
            error_call_name += "_too_small";
        }

        if (wider.is_int()) {
            error_call_name += "_i64";
        } else if (wider.is_uint()) {
            error_call_name += "_u64";
        } else {
            internal_assert(wider.is_float());
            error_call_name += "_f64";
        }

        Expr error = Call::make(Int(32), error_call_name,
                                {p.param_name, p.value, p.limit_value},
                                Call::Extern);

        s = Block::make(AssertStmt::make(p.condition, error), s);
    }

    // The unstructured assertions get checked first (because they
    // have a custom error message associated with them), so prepend
    // them last.
    vector<Stmt> stmts = preconditions;
    stmts.push_back(s);
    return Block::make(stmts);
}
} // namespace Internal
} // namespace Halide