33#include < functional>
44#include < random>
55
6+ #include " random_expr_generator.h"
7+
68// Test the simplifier in Halide by testing for equivalence of randomly generated expressions.
79namespace {
810
@@ -11,241 +13,7 @@ using std::string;
1113using namespace Halide ;
1214using namespace Halide ::Internal;
1315
14- using make_bin_op_fn = Expr (*)(Expr, Expr);
15- using RandomEngine = std::mt19937_64;
16-
17- constexpr int fuzz_var_count = 5 ;
18-
19- Type fuzz_types[] = {UInt (1 ), UInt (8 ), UInt (16 ), UInt (32 ), Int (8 ), Int (16 ), Int (32 )};
20-
21- std::string fuzz_var (int i) {
22- return std::string (1 , ' a' + i);
23- }
24-
25- Expr random_var (RandomEngine &rng, Type t) {
26- std::uniform_int_distribution dist (0 , fuzz_var_count - 1 );
27- int fuzz_count = dist (rng);
28- return cast (t, Variable::make (Int (32 ), fuzz_var (fuzz_count)));
29- }
30-
31- template <typename T>
32- decltype (auto ) random_choice(RandomEngine &rng, T &&choices) {
33- std::uniform_int_distribution<size_t > dist (0 , std::size (choices) - 1 );
34- return choices[dist (rng)];
35- }
36-
37- Type random_type (RandomEngine &rng, int width) {
38- Type t = random_choice (rng, fuzz_types);
39- if (width > 1 ) {
40- t = t.with_lanes (width);
41- }
42- return t;
43- }
44-
45- int get_random_divisor (RandomEngine &rng, Type t) {
46- std::vector<int > divisors = {t.lanes ()};
47- for (int dd = 2 ; dd < t.lanes (); dd++) {
48- if (t.lanes () % dd == 0 ) {
49- divisors.push_back (dd);
50- }
51- }
52-
53- return random_choice (rng, divisors);
54- }
55-
56- Expr random_leaf (RandomEngine &rng, Type t, bool overflow_undef = false , bool imm_only = false ) {
57- if (t.is_int () && t.bits () == 32 ) {
58- overflow_undef = true ;
59- }
60- if (t.is_scalar ()) {
61- if (!imm_only && (rng () & 1 )) {
62- return random_var (rng, t);
63- } else {
64- if (overflow_undef) {
65- // For Int(32), we don't care about correctness during
66- // overflow, so just use numbers that are unlikely to
67- // overflow.
68- return cast (t, (int32_t )((int8_t )(rng () & 255 )));
69- } else {
70- return cast (t, (int32_t )(rng ()));
71- }
72- }
73- } else {
74- int lanes = get_random_divisor (rng, t);
75- if (rng () & 1 ) {
76- auto e1 = random_leaf (rng, t.with_lanes (t.lanes () / lanes), overflow_undef);
77- auto e2 = random_leaf (rng, t.with_lanes (t.lanes () / lanes), overflow_undef);
78- return Ramp::make (e1 , e2 , lanes);
79- } else {
80- auto e1 = random_leaf (rng, t.with_lanes (t.lanes () / lanes), overflow_undef);
81- return Broadcast::make (e1 , lanes);
82- }
83- }
84- }
85-
86- Expr random_expr (RandomEngine &rng, Type t, int depth, bool overflow_undef = false );
87-
88- Expr random_condition (RandomEngine &rng, Type t, int depth, bool maybe_scalar) {
89- static make_bin_op_fn make_bin_op[] = {
90- EQ::make,
91- NE::make,
92- LT::make,
93- LE::make,
94- GT::make,
95- GE::make,
96- };
97-
98- if (maybe_scalar && (rng () & 1 )) {
99- t = t.element_of ();
100- }
101-
102- Expr a = random_expr (rng, t, depth);
103- Expr b = random_expr (rng, t, depth);
104- return random_choice (rng, make_bin_op)(a, b);
105- }
106-
107- Expr make_absd (Expr a, Expr b) {
108- // random_expr() assumes that the result t is the same as the input t,
109- // which isn't true for all absd variants, so force the issue.
110- return cast (a.type (), absd (a, b));
111- }
112-
113- Expr make_bitwise_or (Expr a, Expr b) {
114- return a | b;
115- }
116-
117- Expr make_bitwise_and (Expr a, Expr b) {
118- return a & b;
119- }
120-
121- Expr make_bitwise_xor (Expr a, Expr b) {
122- return a ^ b;
123- }
124-
125- Expr make_abs (Expr a, Expr) {
126- if (!a.type ().is_uint ()) {
127- return cast (a.type (), abs (a));
128- } else {
129- return a;
130- }
131- }
132-
133- Expr make_bitwise_not (Expr a, Expr) {
134- return ~a;
135- }
136-
137- Expr make_shift_right (Expr a, Expr b) {
138- return a >> (b % a.type ().bits ());
139- }
140-
141- Expr random_expr (RandomEngine &rng, Type t, int depth, bool overflow_undef) {
142- if (t.is_int () && t.bits () == 32 ) {
143- overflow_undef = true ;
144- }
145-
146- if (depth-- <= 0 ) {
147- return random_leaf (rng, t, overflow_undef);
148- }
149-
150- std::function<Expr ()> operations[] = {
151- [&]() {
152- return random_leaf (rng, t);
153- },
154- [&]() {
155- auto c = random_condition (rng, t, depth, true );
156- auto e1 = random_expr (rng, t, depth, overflow_undef);
157- auto e2 = random_expr (rng, t, depth, overflow_undef);
158- return select (c, e1 , e2 );
159- },
160- [&]() {
161- if (t.lanes () != 1 ) {
162- int lanes = get_random_divisor (rng, t);
163- auto e1 = random_expr (rng, t.with_lanes (t.lanes () / lanes), depth, overflow_undef);
164- return Broadcast::make (e1 , lanes);
165- }
166- return random_expr (rng, t, depth, overflow_undef);
167- },
168- [&]() {
169- if (t.lanes () != 1 ) {
170- int lanes = get_random_divisor (rng, t);
171- auto e1 = random_expr (rng, t.with_lanes (t.lanes () / lanes), depth, overflow_undef);
172- auto e2 = random_expr (rng, t.with_lanes (t.lanes () / lanes), depth, overflow_undef);
173- return Ramp::make (e1 , e2 , lanes);
174- }
175- return random_expr (rng, t, depth, overflow_undef);
176- },
177- [&]() {
178- if (t.is_bool ()) {
179- auto e1 = random_expr (rng, t, depth);
180- return Not::make (e1 );
181- }
182- return random_expr (rng, t, depth, overflow_undef);
183- },
184- [&]() {
185- // When generating boolean expressions, maybe throw in a condition on non-bool types.
186- if (t.is_bool ()) {
187- return random_condition (rng, random_type (rng, t.lanes ()), depth, false );
188- }
189- return random_expr (rng, t, depth, overflow_undef);
190- },
191- [&]() {
192- // Get a random type that isn't `t` or int32 (int32 can overflow, and we don't care about that).
193- std::vector<Type> subtypes;
194- for (const Type &subtype : fuzz_types) {
195- if (subtype != t && subtype != Int (32 )) {
196- subtypes.push_back (subtype);
197- }
198- }
199- Type subtype = random_choice (rng, subtypes).with_lanes (t.lanes ());
200- return Cast::make (t, random_expr (rng, subtype, depth, overflow_undef));
201- },
202- [&]() {
203- static make_bin_op_fn make_bin_op[] = {
204- // Arithmetic operations.
205- Add::make,
206- Sub::make,
207- Mul::make,
208- Min::make,
209- Max::make,
210- Div::make,
211- Mod::make,
212- };
213-
214- static make_bin_op_fn make_rare_bin_op[] = {
215- make_absd,
216- make_bitwise_or,
217- make_bitwise_and,
218- make_bitwise_xor,
219- make_bitwise_not,
220- make_abs,
221- make_shift_right, // No shift left or we just keep testing integer overflow
222- };
223-
224- Expr a = random_expr (rng, t, depth, overflow_undef);
225- Expr b = random_expr (rng, t, depth, overflow_undef);
226- if ((rng () & 7 ) == 0 ) {
227- return random_choice (rng, make_rare_bin_op)(a, b);
228- } else {
229- return random_choice (rng, make_bin_op)(a, b);
230- }
231- },
232- [&]() {
233- static make_bin_op_fn make_bin_op[] = {
234- And::make,
235- Or::make,
236- };
237-
238- // Boolean operations -- both sides must be cast to booleans,
239- // and then we must cast the result back to 't'.
240- Expr a = random_expr (rng, t, depth, overflow_undef);
241- Expr b = random_expr (rng, t, depth, overflow_undef);
242- Type bool_with_lanes = Bool (t.lanes ());
243- a = cast (bool_with_lanes, a);
244- b = cast (bool_with_lanes, b);
245- return cast (t, random_choice (rng, make_bin_op)(a, b));
246- }};
247- return random_choice (rng, operations)();
248- }
16+ using RandomEngine = RandomExpressionGenerator::RandomEngine;
24917
25018bool test_simplification (Expr a, Expr b, Type t, const map<string, Expr> &vars) {
25119 if (equal (a, b) && !a.same_as (b)) {
@@ -292,12 +60,12 @@ bool test_simplification(Expr a, Expr b, Type t, const map<string, Expr> &vars)
29260 return true ;
29361}
29462
295- bool test_expression (RandomEngine &rng, Expr test, int samples) {
63+ bool test_expression (RandomExpressionGenerator ®, RandomEngine &rng, Expr test, int samples) {
29664 Expr simplified = simplify (test);
29765
29866 map<string, Expr> vars;
299- for (int i = 0 ; i < fuzz_var_count ; i++) {
300- vars[fuzz_var (i)] = Expr ();
67+ for (int i = 0 ; i < ( int )reg. fuzz_vars . size () ; i++) {
68+ vars[reg. fuzz_var (i)] = Expr ();
30169 }
30270
30371 for (int i = 0 ; i < samples; i++) {
@@ -306,7 +74,7 @@ bool test_expression(RandomEngine &rng, Expr test, int samples) {
30674 // Don't let the random leaf depend on v itself.
30775 size_t iterations = 0 ;
30876 do {
309- val = random_leaf (rng, Int (32 ), true );
77+ val = reg. random_leaf (rng, Int (32 ), true );
31078 iterations++;
31179 } while (expr_uses_var (val, var) && iterations < kMaxLeafIterations );
31280 }
@@ -337,6 +105,16 @@ int main(int argc, char **argv) {
337105
338106 auto seed_generator = initialize_rng<RandomEngine>();
339107
108+ RandomExpressionGenerator reg;
109+ reg.fuzz_types = {UInt (1 ), UInt (8 ), UInt (16 ), UInt (32 ), Int (8 ), Int (16 ), Int (32 )};
110+ // FIXME: UInt64 fails!
111+ // FIXME: These need to be disabled (otherwise crashes and/or failures):
112+ // reg.gen_ramp_of_vector = false;
113+ // reg.gen_broadcast_of_vector = false;
114+ reg.gen_vector_reduce = false ;
115+ reg.gen_reinterpret = false ;
116+ reg.gen_shuffles = false ;
117+
340118 for (int i = 0 ; i < ((argc == 1 ) ? 10000 : 1 ); i++) {
341119 auto seed = seed_generator ();
342120 if (argc > 1 ) {
@@ -347,11 +125,12 @@ int main(int argc, char **argv) {
347125 std::cout << " Seed: " << seed << " \n " ;
348126 RandomEngine rng{seed};
349127 std::array<int , 6 > vector_widths = {1 , 2 , 3 , 4 , 6 , 8 };
350- int width = random_choice (rng, vector_widths);
351- Type VT = random_type (rng, width);
128+ int width = reg. random_choice (rng, vector_widths);
129+ Type VT = reg. random_type (rng, width);
352130 // Generate a random expr...
353- Expr test = random_expr (rng, VT, depth);
354- if (!test_expression (rng, test, samples)) {
131+ Expr test = reg.random_expr (rng, VT, depth);
132+ std::cout << test << " \n " ;
133+ if (!test_expression (reg, rng, test, samples)) {
355134
356135 class LimitDepth : public IRMutator {
357136 int limit;
@@ -378,6 +157,7 @@ int main(int argc, char **argv) {
378157 // Failure. Find the minimal subexpression that failed.
379158 std::cout << " Testing subexpressions...\n " ;
380159 class TestSubexpressions : public IRMutator {
160+ RandomExpressionGenerator reg;
381161 RandomEngine &rng;
382162 bool found_failure = false ;
383163
@@ -392,19 +172,19 @@ int main(int argc, char **argv) {
392172 Expr limited;
393173 for (int i = 1 ; i < 4 && !found_failure; i++) {
394174 limited = LimitDepth (i).mutate (e);
395- found_failure = !test_expression (rng, limited, samples);
175+ found_failure = !test_expression (reg, rng, limited, samples);
396176 }
397177 if (!found_failure) {
398- found_failure = !test_expression (rng, e, samples);
178+ found_failure = !test_expression (reg, rng, e, samples);
399179 }
400180 }
401181 return e;
402182 }
403183
404- TestSubexpressions (RandomEngine &rng)
405- : rng(rng) {
184+ TestSubexpressions (RandomExpressionGenerator ®, RandomEngine &rng)
185+ : reg(reg), rng(rng) {
406186 }
407- } tester (rng);
187+ } tester (reg, rng);
408188 tester.mutate (test);
409189
410190 std::cout << " Failed with seed " << seed << " \n " ;
0 commit comments