-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathRegionCosts.cpp
More file actions
844 lines (733 loc) · 27.4 KB
/
RegionCosts.cpp
File metadata and controls
844 lines (733 loc) · 27.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
#include "RegionCosts.h"
#include "FindCalls.h"
#include "Function.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "PartitionLoops.h"
#include "RealizationOrder.h"
#include "Simplify.h"
namespace Halide {
namespace Internal {
using std::map;
using std::set;
using std::string;
using std::vector;
namespace {
// Visitor for keeping track of all input images accessed and their types.
class FindImageInputs : public IRVisitor {
using IRVisitor::visit;
set<string> seen_image_param;
void visit(const Call *call) override {
if (call->call_type == Call::Image) {
input_type[call->name] = call->type;
// Call to an ImageParam
if (call->param.defined() && (seen_image_param.count(call->name) == 0)) {
for (int i = 0; i < call->param.dimensions(); ++i) {
const Expr &min = call->param.min_constraint_estimate(i);
const Expr &extent = call->param.extent_constraint_estimate(i);
user_assert(min.defined())
<< "AutoSchedule: Estimate of the min value of ImageParam \""
<< call->name << "\" in dimension " << i << " is not specified.\n";
user_assert(extent.defined())
<< "AutoSchedule: Estimate of the extent value of ImageParam \""
<< call->name << "\" in dimension " << i << " is not specified.\n";
string min_var = call->param.name() + ".min." + std::to_string(i);
string extent_var = call->param.name() + ".extent." + std::to_string(i);
input_estimates.emplace(min_var, Interval(min, min));
input_estimates.emplace(extent_var, Interval(extent, extent));
seen_image_param.insert(call->name);
}
}
}
for (const auto &arg : call->args) {
arg.accept(this);
}
}
public:
map<string, Type> input_type;
map<string, Interval> input_estimates;
};
// Visitor that accumulates the arithmetic and memory costs of an expression.
// The totals are exposed through the public 'arith' and 'memory' members,
// and 'detailed_byte_loads' breaks the loaded bytes down per callee.
class ExprCost : public IRVisitor {
    using IRVisitor::visit;

    // Immediate values and variable references are free.
    void visit(const IntImm *) override {
    }

    void visit(const UIntImm *) override {
    }

    void visit(const FloatImm *) override {
    }

    void visit(const StringImm *) override {
    }

    void visit(const Variable *) override {
    }

    // A cast costs one op on top of the cost of its operand.
    void visit(const Cast *op) override {
        op->value.accept(this);
        arith += 1;
    }

    // `Reinterpret` is a no-op and does *not* incur any cost of its own;
    // only its operand is costed.
    void visit(const Reinterpret *op) override {
        op->value.accept(this);
    }

    // Cost both operands of a binary node, then charge 'op_cost' for the
    // operation itself.
    template<typename T>
    void accumulate_binary(const T *op, int op_cost) {
        op->a.accept(this);
        op->b.accept(this);
        arith += op_cost;
    }

    // Every simple binary operation is charged a cost of one.
    // TODO: Changing the costs for division and multiplication may be
    // beneficial. Write a test case to validate this and update the costs
    // accordingly.
    void visit(const Add *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const Sub *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const Mul *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const Div *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const Mod *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const Min *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const Max *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const EQ *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const NE *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const LT *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const LE *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const GT *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const GE *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const And *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const Or *op) override {
        accumulate_binary(op, 1);
    }

    void visit(const Not *op) override {
        op->a.accept(this);
        arith += 1;
    }

    // A select pays for the condition and both branches, plus one op.
    void visit(const Select *op) override {
        op->condition.accept(this);
        op->true_value.accept(this);
        op->false_value.accept(this);
        arith += 1;
    }

    void visit(const Call *call) override {
        if (call->is_intrinsic(Call::if_then_else)) {
            internal_assert(call->args.size() == 2 || call->args.size() == 3);

            // Set aside what has been accumulated so far, so the cost of
            // each sub-expression can be measured starting from zero.
            const int64_t outer_arith = arith;
            const int64_t outer_memory = memory;

            // Cost of the false branch (zero when there is none).
            arith = 0;
            memory = 0;
            if (call->args.size() == 3) {
                call->args[2].accept(this);
            }

            // Check if this if_then_else is because of tracing or print_when.
            // If it is, we should only take into account the cost of computing
            // the false expr since the true expr is debugging/tracing code.
            const Call *then_call = call->args[1].as<Call>();
            const bool then_is_debug_code =
                then_call && then_call->is_intrinsic(Call::return_second);
            if (!then_is_debug_code) {
                const int64_t else_arith = arith;
                const int64_t else_memory = memory;

                arith = 0;
                memory = 0;
                call->args[0].accept(this);
                const int64_t cond_arith = arith;
                const int64_t cond_memory = memory;

                arith = 0;
                memory = 0;
                call->args[1].accept(this);
                const int64_t then_arith = arith;
                const int64_t then_memory = memory;

                // The cost is that of the predicate plus the more expensive
                // of the two branches.
                arith = cond_arith + std::max(then_arith, else_arith);
                memory = cond_memory + std::max(then_memory, else_memory);
            }

            arith += outer_arith;
            memory += outer_memory;
            return;
        }

        if (call->is_intrinsic(Call::return_second)) {
            // For return_second, since the first expr would usually either be a
            // print_when or tracing, we should only take into account the cost
            // of computing the second expr.
            internal_assert(call->args.size() == 2);
            call->args[1].accept(this);
            return;
        }

        if (call->call_type == Call::Halide || call->call_type == Call::Image) {
            // Each call also counts as an op since it results in a load instruction.
            const int64_t bytes_loaded = call->type.bytes();
            arith += 1;
            memory += bytes_loaded;
            detailed_byte_loads[call->name] += bytes_loaded;
        } else if (call->is_extern()) {
            // TODO: Suffix based matching is kind of sketchy; but going ahead with
            // it for now. Also not all the PureExtern's are accounted for yet.
            int extern_cost = 0;
            if (ends_with(call->name, "_f64")) {
                extern_cost = 20;
            } else if (ends_with(call->name, "_f32")) {
                extern_cost = 10;
            } else if (ends_with(call->name, "_f16")) {
                extern_cost = 5;
            } else {
                // There is no visibility into an extern stage so there is no
                // way to know the cost of the call statically. Modeling the
                // cost of an extern stage requires profiling or user annotation.
                user_warning << "Unknown extern call " << call->name << "\n";
            }
            arith += extern_cost;
        } else if (call->is_intrinsic()) {
            if (call->is_intrinsic({Call::lerp,
                                    Call::div_round_to_zero,
                                    Call::mod_round_to_zero,
                                    Call::random})) {
                // It's an expensive arithmetic intrinsic
                arith += 5;
            } else if (!Call::as_tag(call) &&
                       !call->is_intrinsic({Call::promise_clamped,
                                            Call::unsafe_promise_clamped,
                                            Call::undef})) {
                // Tags and the promise/undef intrinsics entail no actual
                // work. For the other intrinsics (e.g. bitwise ops,
                // fixed-point math), use 1 for the arithmetic cost.
                arith += 1;
            }
        }

        // The arguments of the call are costed as well.
        for (const auto &call_arg : call->args) {
            call_arg.accept(this);
        }
    }

    // A let pays for both the bound value and the body.
    void visit(const Let *let) override {
        let->value.accept(this);
        let->body.accept(this);
    }

    // None of the following IR nodes should be encountered when traversing the
    // IR at the level at which the auto scheduler operates.
    void fail(const Expr &e) {
        internal_error << "Unexpected Expr while computing region costs: " << e << "\n"
                       << "Expected front-end Exprs only.";
    }

    void fail(const Stmt &s) {
        internal_error << "Unexpected Stmt while computing region costs:\n"
                       << s << "\n"
                       << "Expected front-end Exprs only.";
    }

    void visit(const Load *op) override {
        fail(op);
    }

    void visit(const Ramp *op) override {
        fail(op);
    }

    void visit(const Shuffle *op) override {
        fail(op);
    }

    void visit(const Broadcast *op) override {
        fail(op);
    }

    void visit(const LetStmt *op) override {
        fail(op);
    }

    void visit(const AssertStmt *op) override {
        fail(op);
    }

    void visit(const ProducerConsumer *op) override {
        fail(op);
    }

    void visit(const For *op) override {
        fail(op);
    }

    void visit(const Store *op) override {
        fail(op);
    }

    void visit(const Provide *op) override {
        fail(op);
    }

    void visit(const Allocate *op) override {
        fail(op);
    }

    void visit(const Free *op) override {
        fail(op);
    }

    void visit(const Realize *op) override {
        fail(op);
    }

    void visit(const Block *op) override {
        fail(op);
    }

    void visit(const IfThenElse *op) override {
        fail(op);
    }

    void visit(const Evaluate *op) override {
        fail(op);
    }

public:
    // Accumulated arithmetic cost.
    int64_t arith = 0;
    // Accumulated memory cost, in bytes loaded.
    int64_t memory = 0;
    // Detailed breakdown of bytes loaded by the allocation or function
    // they are loaded from.
    map<string, int64_t> detailed_byte_loads;

    ExprCost() = default;
};
// Return the number of bytes needed to store a single value of the function,
// summed over all of its output types.
Expr get_func_value_size(const Function &f) {
    const vector<Type> &output_types = f.output_types();
    internal_assert(!output_types.empty());

    Expr total_bytes = 0;
    for (const Type &output_type : output_types) {
        total_bytes += output_type.bytes();
    }
    return simplify(total_bytes);
}
// Helper class that only accounts for the likely portion of the expression in
// the case of max, min, and select. This helps cost functions with boundary
// conditions more accurately: the likely intrinsic triggers loop partitioning,
// so on average (in the steady state) the cost of the expression is equivalent
// to the cost of its likely portion.
//
// TODO: This class is commented out for now, until compute_expr_cost and
// compute_expr_detailed_byte_loads are modified to account for likely exprs.
/*class LikelyExpression : public IRMutator {
using IRMutator::visit;
Expr visit(const Min *op) override {
IRVisitor::visit(op);
bool likely_a = has_likely_tag(op->a);
bool likely_b = has_likely_tag(op->b);
if (likely_a && !likely_b) {
return op->a;
} else if (likely_b && !likely_a) {
return op->b;
}
}
Expr visit(const Max *op) override {
IRVisitor::visit(op);
bool likely_a = has_likely_tag(op->a);
bool likely_b = has_likely_tag(op->b);
if (likely_a && !likely_b) {
return op->a;
} else if (likely_b && !likely_a) {
return op->b;
}
}
Expr visit(const Select *op) override {
IRVisitor::visit(op);
bool likely_t = has_likely_tag(op->true_value);
bool likely_f = has_likely_tag(op->false_value);
if (likely_t && !likely_f) {
return op->true_value;
} else if (likely_f && !likely_t) {
return op->false_value;
}
}
};*/
// Simplify 'expr' and return its arithmetic and memory costs as measured by
// ExprCost.
Cost compute_expr_cost(Expr expr) {
    // TODO: Handle likely
    // expr = LikelyExpression().mutate(expr);
    ExprCost visitor;
    simplify(expr).accept(&visitor);
    return Cost(visitor.arith, visitor.memory);
}
// Simplify 'expr' and return the number of bytes it loads, broken down by
// the name of the function or image each load reads from.
map<string, Expr> compute_expr_detailed_byte_loads(Expr expr) {
    // TODO: Handle likely
    // expr = LikelyExpression().mutate(expr);
    ExprCost visitor;
    simplify(expr).accept(&visitor);

    // Wrap each integer byte count into an Expr.
    map<string, Expr> bytes_by_source;
    for (const auto &entry : visitor.detailed_byte_loads) {
        bytes_by_source.emplace(entry.first, Expr(entry.second));
    }
    return bytes_by_source;
}
} // anonymous namespace
// Build the per-function cost tables for the pipeline described by '_env',
// and gather the types and estimates of all of its image inputs.
RegionCosts::RegionCosts(const map<string, Function> &_env,
                         const vector<string> &_order)
    : env(_env), order(_order) {
    for (const auto &entry : env) {
        const string &name = entry.first;
        const Function &function = entry.second;

        // Pre-compute the function costs without any inlining.
        func_cost[name] = get_func_cost(function);

        // Get the types of all the image inputs to the pipeline, including
        // their estimated min/extent values if applicable (i.e. if they are
        // ImageParam).
        FindImageInputs finder;
        function.accept(&finder);
        for (const auto &input : finder.input_type) {
            inputs[input.first] = input.second;
        }
        for (const auto &estimate : finder.input_estimates) {
            input_estimates.push(estimate.first, estimate.second);
        }
    }
}
// Return the cost of computing stage 'stage' of 'func' over the given
// per-dimension bounds: the per-point cost of the stage scaled by the number
// of points. An undefined Cost is returned if either the size of the region
// or the per-point cost cannot be determined.
Cost RegionCosts::stage_region_cost(const string &func, int stage, const DimBounds &bounds,
                                    const set<string> &inlines) {
    Function curr_f = get_element(env, func);

    // Collect the bounds of every dimension of the stage except the last one.
    const vector<Dim> &dims = get_stage_dims(curr_f, stage);
    const int num_dims = static_cast<int>(dims.size()) - 1;
    Box stage_region;
    for (int d = 0; d < num_dims; ++d) {
        stage_region.push_back(get_element(bounds, dims[d].var));
    }

    const Expr num_points = box_size(stage_region);
    if (!num_points.defined()) {
        // Size could not be determined; therefore, it is not possible to
        // determine the arithmetic and memory costs.
        return Cost();
    }

    // If there is nothing to be inlined, use the pre-computed function cost.
    Cost per_point;
    if (inlines.empty()) {
        per_point = get_element(func_cost, func)[stage];
    } else {
        per_point = get_func_stage_cost(curr_f, stage, inlines);
    }
    if (!per_point.defined()) {
        return Cost();
    }

    return Cost(simplify(num_points * per_point.arith),
                simplify(num_points * per_point.memory));
}
// Return the cost of computing stage 'stage' of 'func' over 'region', where
// 'region' holds one interval per pure argument of the function. The region
// is converted into bounds on the dimensions of the stage and the work is
// delegated to the DimBounds overload.
Cost RegionCosts::stage_region_cost(const string &func, int stage, const Box &region,
                                    const set<string> &inlines) {
    Function curr_f = get_element(env, func);

    // Pair each pure argument of the function with its interval in 'region'.
    const vector<string> &args = curr_f.args();
    internal_assert(args.size() == region.size());
    DimBounds pure_bounds;
    for (size_t d = 0; d < args.size(); d++) {
        pure_bounds.emplace(args[d], region[d]);
    }

    DimBounds stage_bounds = get_stage_bounds(curr_f, stage, pure_bounds);
    return stage_region_cost(func, stage, stage_bounds, inlines);
}
// Return the cost of computing all stages (the pure definition and every
// update) of 'func' over 'region'. An undefined Cost is returned as soon as
// the cost of any stage cannot be determined.
Cost RegionCosts::region_cost(const string &func, const Box &region, const set<string> &inlines) {
    Function curr_f = get_element(env, func);

    // One stage for the pure definition plus one per update definition.
    const int num_stages = curr_f.updates().size() + 1;

    Cost total(0, 0);
    for (int s = 0; s < num_stages; s++) {
        Cost stage_cost = stage_region_cost(func, s, region, inlines);
        if (!stage_cost.defined()) {
            return Cost();
        }
        total.arith += stage_cost.arith;
        total.memory += stage_cost.memory;
    }

    internal_assert(total.defined());
    total.simplify();
    return total;
}
// Return the combined cost of computing every function in 'regions' over its
// associated region. Functions listed in 'inlines' are skipped. An undefined
// Cost is returned as soon as the cost of any function cannot be determined.
Cost RegionCosts::region_cost(const map<string, Box> &regions, const set<string> &inlines) {
    Cost total_cost(0, 0);
    for (const auto &f : regions) {
        // The cost for pure inlined functions will be accounted in the
        // consumer of the inlined function so they should be skipped.
        if (inlines.find(f.first) != inlines.end()) {
            internal_assert(get_element(env, f.first).is_pure());
            continue;
        }

        Cost cost = region_cost(f.first, f.second, inlines);
        if (!cost.defined()) {
            return Cost();
        }
        total_cost.arith += cost.arith;
        total_cost.memory += cost.memory;
    }

    internal_assert(total_cost.defined());
    total_cost.simplify();
    return total_cost;
}
// Return the number of bytes accessed per point of stage 'stage' of 'func',
// broken down by the function or image accessed. The entry for 'func' itself
// additionally includes the bytes of each value of the stage. For a function
// with an extern definition, the only entry is an undefined Expr for 'func'.
map<string, Expr>
RegionCosts::stage_detailed_load_costs(const string &func, int stage,
                                       const set<string> &inlines) {
    map<string, Expr> load_costs;
    Function curr_f = get_element(env, func);

    if (curr_f.has_extern_definition()) {
        // TODO(psuriana): We need a better cost for extern function
        // load_costs.emplace(func, Int(64).max());
        load_costs.emplace(func, Expr());
        return load_costs;
    }

    Definition def = get_stage_definition(curr_f, stage);
    for (const auto &value : def.values()) {
        // Loads performed by the value once the requested functions have
        // been inlined into it.
        Expr inlined_value = simplify(perform_inline(value, env, inlines, order));
        map<string, Expr> value_load_costs = compute_expr_detailed_byte_loads(inlined_value);
        combine_load_costs(load_costs, value_load_costs);

        // Add the size of the value itself to the entry of 'func'.
        auto self_entry = load_costs.find(func);
        if (self_entry == load_costs.end()) {
            load_costs.emplace(func, make_const(Int(64), value.type().bytes()));
        } else {
            internal_assert(self_entry->second.defined());
            self_entry->second = simplify(self_entry->second + value.type().bytes());
        }
    }
    return load_costs;
}
// Return the detailed load costs of stage 'stage' of 'func' over the given
// bounds: the per-point costs scaled by the number of points in the stage
// region. Entries become undefined when the region size is undefined.
map<string, Expr>
RegionCosts::stage_detailed_load_costs(const string &func, int stage,
                                       DimBounds &bounds,
                                       const set<string> &inlines) {
    Function curr_f = get_element(env, func);

    // Collect the bounds of every dimension of the stage except the last one.
    const vector<Dim> &dims = get_stage_dims(curr_f, stage);
    const int num_dims = static_cast<int>(dims.size()) - 1;
    Box stage_region;
    for (int d = 0; d < num_dims; ++d) {
        stage_region.push_back(get_element(bounds, dims[d].var));
    }

    map<string, Expr> load_costs = stage_detailed_load_costs(func, stage, inlines);

    const Expr num_points = box_size(stage_region);
    for (auto &entry : load_costs) {
        Expr &bytes = entry.second;
        if (!bytes.defined()) {
            // Already unknown; nothing to scale.
            continue;
        }
        bytes = num_points.defined() ? simplify(bytes * num_points) : Expr();
    }
    return load_costs;
}
// Return the detailed load costs of computing all stages of 'func' over
// 'region', where 'region' holds one interval per pure argument of the
// function. The costs of the individual stages are obtained from the
// bounds-taking overload of stage_detailed_load_costs and then combined.
map<string, Expr>
RegionCosts::detailed_load_costs(const string &func, const Box &region,
                                 const set<string> &inlines) {
    Function curr_f = get_element(env, func);

    // Pair each pure argument of the function with its interval in 'region'.
    const vector<string> &args = curr_f.args();
    internal_assert(args.size() == region.size());
    DimBounds pure_bounds;
    for (size_t d = 0; d < args.size(); d++) {
        pure_bounds.emplace(args[d], region[d]);
    }

    // Bounds of the dimensions of every stage, indexed by stage.
    vector<DimBounds> stage_bounds = get_stage_bounds(curr_f, pure_bounds);

    // One stage for the pure definition plus one per update definition.
    const int num_stages = curr_f.updates().size() + 1;

    map<string, Expr> load_costs;
    for (int s = 0; s < num_stages; s++) {
        // Per-point load costs of the stage, scaled by the size of the
        // region of the stage.
        map<string, Expr> stage_load_costs =
            stage_detailed_load_costs(func, s, stage_bounds[s], inlines);
        combine_load_costs(load_costs, stage_load_costs);
    }
    return load_costs;
}
// Return the combined detailed load costs of computing every function in
// 'regions' over its associated region. Functions listed in 'inlines' are
// skipped.
map<string, Expr>
RegionCosts::detailed_load_costs(const map<string, Box> &regions,
                                 const set<string> &inlines) {
    map<string, Expr> load_costs;
    for (const auto &r : regions) {
        // The cost for pure inlined functions will be accounted in the
        // consumer of the inlined function so they should be skipped.
        if (inlines.find(r.first) != inlines.end()) {
            internal_assert(get_element(env, r.first).is_pure());
            continue;
        }

        map<string, Expr> partial_load_costs = detailed_load_costs(r.first, r.second, inlines);
        combine_load_costs(load_costs, partial_load_costs);
    }
    return load_costs;
}
// Return the per-point cost of stage 'stage' of 'f' after inlining the
// functions in 'inlines'. The cost covers every value of the stage, the
// store of each value and, for non-pure functions, the arguments of the
// definition. Functions with an extern definition have an undefined cost.
Cost RegionCosts::get_func_stage_cost(const Function &f, int stage,
                                      const set<string> &inlines) const {
    if (f.has_extern_definition()) {
        return Cost();
    }

    Definition def = get_stage_definition(f, stage);
    Cost total(0, 0);

    // Inline and simplify 'e', then add its cost to the running total.
    const auto add_cost_of = [&](const Expr &e) {
        Expr inlined = simplify(perform_inline(e, env, inlines, order));
        Cost expr_cost = compute_expr_cost(inlined);
        internal_assert(expr_cost.defined());
        total.arith += expr_cost.arith;
        total.memory += expr_cost.memory;
    };

    for (const auto &value : def.values()) {
        add_cost_of(value);

        // Accounting for the store
        total.memory += value.type().bytes();
        total.arith += 1;
    }

    if (!f.is_pure()) {
        for (const auto &arg : def.args()) {
            add_cost_of(arg);
        }
    }

    total.simplify();
    return total;
}
// Return the per-point cost of every stage of 'f' (the pure definition
// followed by each update) after inlining the functions in 'inlines'. For a
// function with an extern definition, a single undefined Cost is returned.
vector<Cost> RegionCosts::get_func_cost(const Function &f, const set<string> &inlines) {
    if (f.has_extern_definition()) {
        return vector<Cost>(1, Cost());
    }

    const int num_stages = static_cast<int>(f.updates().size()) + 1;
    vector<Cost> stage_costs;
    stage_costs.reserve(num_stages);
    for (int stage = 0; stage < num_stages; ++stage) {
        stage_costs.push_back(get_func_stage_cost(f, stage, inlines));
    }
    return stage_costs;
}
// Return the size in bytes of 'region' of 'func': the number of points in
// the region times the number of bytes needed to store one value of the
// function. An undefined Expr is returned if the size of the region cannot
// be determined.
Expr RegionCosts::region_size(const string &func, const Box &region) {
    const Function &f = get_element(env, func);

    Expr size = box_size(region);
    if (!size.defined()) {
        return Expr();
    }

    Expr size_per_ele = get_func_value_size(f);
    internal_assert(size_per_ele.defined());
    return simplify(size * size_per_ele);
}
// Return the maximum combined size of the regions that are held at the same
// time while the functions in 'regions' are processed in topological order.
// The region of a function is added when the function is reached and dropped
// once all of its consumers within 'regions' have been processed. Functions
// in 'inlined' contribute a size of zero. An undefined Expr is returned if
// the size of any region cannot be determined.
Expr RegionCosts::region_footprint(const map<string, Box> &regions,
                                   const set<string> &inlined) {
    // For every function in 'regions', count how many functions in 'regions'
    // call it directly.
    map<string, int> num_consumers;
    for (const auto &f : regions) {
        num_consumers[f.first] = 0;
    }
    for (const auto &f : regions) {
        map<string, Function> prods = find_direct_calls(get_element(env, f.first));
        for (const auto &p : prods) {
            auto iter = num_consumers.find(p.first);
            if (iter != num_consumers.end()) {
                iter->second += 1;
            }
        }
    }

    // The functions without any consumer in 'regions' are the starting
    // points for computing the topological order.
    vector<Function> outs;
    for (const auto &f : num_consumers) {
        if (f.second == 0) {
            outs.push_back(get_element(env, f.first));
        }
    }
    vector<string> top_order = topological_order(outs, env);

    // Size of the region of every function in 'regions'.
    map<string, Expr> func_sizes;
    for (const auto &f : regions) {
        // Inlined functions do not have allocations
        bool is_inlined = inlined.find(f.first) != inlined.end();
        Expr size = is_inlined ? make_zero(Int(64)) : region_size(f.first, f.second);
        if (!size.defined()) {
            return Expr();
        }
        func_sizes.emplace(f.first, size);
    }

    Expr working_set_size = make_zero(Int(64));
    Expr curr_size = make_zero(Int(64));
    for (const auto &f : top_order) {
        if (regions.find(f) != regions.end()) {
            curr_size += get_element(func_sizes, f);
        }
        working_set_size = max(curr_size, working_set_size);

        // Drop the region of every producer whose last remaining consumer
        // was the function that has just been processed.
        map<string, Function> prods = find_direct_calls(get_element(env, f));
        for (const auto &p : prods) {
            auto iter = num_consumers.find(p.first);
            if (iter != num_consumers.end()) {
                iter->second -= 1;
                if (iter->second == 0) {
                    curr_size -= get_element(func_sizes, p.first);
                    internal_assert(!can_prove(curr_size < 0));
                }
            }
        }
    }

    return simplify(working_set_size);
}
// Return the size in bytes of 'region' of the input image 'input': the
// number of points in the region times the size in bytes of the element type
// recorded for the input. An undefined Expr is returned if the size of the
// region cannot be determined.
Expr RegionCosts::input_region_size(const string &input, const Box &region) {
    Expr size = box_size(region);
    if (!size.defined()) {
        return Expr();
    }

    Expr size_per_ele = make_const(Int(64), get_element(inputs, input).bytes());
    internal_assert(size_per_ele.defined());
    return simplify(size * size_per_ele);
}
// Return the total size in bytes of all the regions in 'input_regions',
// which maps the name of each input image to the region accessed. An
// undefined Expr is returned as soon as the size of any region cannot be
// determined.
Expr RegionCosts::input_region_size(const map<string, Box> &input_regions) {
    Expr total_size = make_zero(Int(64));
    for (const auto &reg : input_regions) {
        Expr size = input_region_size(reg.first, reg.second);
        if (!size.defined()) {
            return Expr();
        }
        total_size += size;
    }
    return simplify(total_size);
}
void RegionCosts::disp_func_costs() {
debug(0) << "===========================\n"
<< "Pipeline per element costs:\n"
<< "===========================\n";
for (const auto &kv : env) {
int stage = 0;
for (const auto &cost : func_cost[kv.first]) {
if (kv.second.has_extern_definition()) {
debug(0) << "Extern func\n";
} else {
Definition def = get_stage_definition(kv.second, stage);
for (const auto &e : def.values()) {
debug(0) << simplify(e) << "\n";
}
}
debug(0) << "(" << kv.first << ", " << stage << ") -> ("
<< cost.arith << ", " << cost.memory << ")\n";
stage++;
}
}
debug(0) << "===========================\n";
}
bool is_func_trivial_to_inline(const Function &func) {
if (!func.can_be_inlined()) {
return false;
}
// For multi-dimensional tuple, we want to take the max over the arithmetic
// and memory cost separately for conservative estimate.
Cost inline_cost(0, 0);
for (const auto &val : func.values()) {
Cost cost = compute_expr_cost(val);
internal_assert(cost.defined());
inline_cost.arith = max(cost.arith, inline_cost.arith);
inline_cost.memory = max(cost.memory, inline_cost.memory);
}
// Compute the cost if we were to call the function instead of inline it
Cost call_cost(1, 0);
for (const auto &type : func.output_types()) {
call_cost.memory = max(type.bytes(), call_cost.memory);
}
Expr is_trivial = (call_cost.arith + call_cost.memory) >= (inline_cost.arith + inline_cost.memory);
return can_prove(is_trivial);
}
void Cost::simplify() {
if (arith.defined()) {
arith = Internal::simplify(arith);
}
if (memory.defined()) {
memory = Internal::simplify(memory);
}
}
} // namespace Internal
} // namespace Halide