-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathAlignLoads.cpp
More file actions
172 lines (142 loc) · 6.71 KB
/
AlignLoads.cpp
File metadata and controls
172 lines (142 loc) · 6.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include <algorithm>
#include "AlignLoads.h"
#include "HexagonAlignment.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "ModulusRemainder.h"
#include "Scope.h"
#include "Simplify.h"
using std::vector;
namespace Halide {
namespace Internal {
namespace {
// This mutator attempts to rewrite unaligned or strided loads to
// sequences of aligned loads by loading aligned vectors that cover
// the original unaligned load, and then slicing or shuffling the
// intended vector out of the aligned vector.
class AlignLoads : public IRMutator {
public:
    /** @param alignment Required alignment in bytes. It is also used to
     *  derive the native vector width (alignment / element size in bytes).
     *  @param min_bytes Vector loads whose total size is at most this many
     *  bytes are left alone. */
    AlignLoads(int alignment, int min_bytes)
        : alignment_analyzer(alignment), required_alignment(alignment), min_bytes_to_align(min_bytes) {
    }

private:
    HexagonAlignmentAnalyzer alignment_analyzer;

    // Loads and stores should ideally be aligned to the vector width in bytes.
    int required_alignment;

    // Minimum size of load to align.
    int min_bytes_to_align;

    using IRMutator::visit;

    // Rewrite a load to have a new index, updating the type if necessary.
    // The replacement load is unpredicated and is itself mutated again, so
    // it can be further aligned or split by visit(const Load *).
    Expr make_load(const Load *load, const Expr &index, ModulusRemainder alignment) {
        internal_assert(is_const_one(load->predicate)) << "Load should not be predicated.\n";
        return mutate(Load::make(load->type.with_lanes(index.type().lanes()), load->name,
                                 index, load->image, load->param,
                                 const_true(index.type().lanes()),
                                 alignment));
    }

    // Return `op` with its index replaced by `index`, which must already be
    // the result of mutate(op->index). For the unpredicated loads this is
    // used on, that matches the IRMutator fallback, but it avoids traversing
    // the index a second time.
    Expr with_mutated_index(const Load *op, const Expr &index) {
        if (index.same_as(op->index)) {
            return op;
        }
        return Load::make(op->type, op->name, index, op->image, op->param,
                          op->predicate, op->alignment);
    }

    Expr visit(const Load *op) override {
        if (!is_const_one(op->predicate)) {
            // TODO(psuriana): Do nothing to predicated loads for now.
            return IRMutator::visit(op);
        }

        if (!op->type.is_vector()) {
            // Nothing to do for scalar loads.
            return IRMutator::visit(op);
        }

        if (op->image.defined()) {
            // We can't reason about the alignment of external images.
            return IRMutator::visit(op);
        }

        if (required_alignment % op->type.bytes() != 0) {
            // A native vector would not hold a whole number of elements.
            return IRMutator::visit(op);
        }

        if (op->type.bytes() * op->type.lanes() <= min_bytes_to_align) {
            // These can probably be treated as scalars instead.
            return IRMutator::visit(op);
        }

        Expr index = mutate(op->index);
        const Ramp *ramp = index.as<Ramp>();
        auto const_stride = ramp ? as_const_int(ramp->stride) : std::nullopt;
        if (!ramp || !const_stride || ramp->base.type().is_vector()) {
            // We can't handle indirect loads, loads with non-constant
            // strides, or nested ramps (a ramp whose base is a vector): the
            // rewrites below build new ramps from a scalar base and treat
            // ramp->lanes as the full width of the load.
            return with_mutated_index(op, index);
        }

        if (!(*const_stride == 1 || *const_stride == 2 || *const_stride == 3 || *const_stride == 4)) {
            // Handle ramps with stride 1, 2, 3 or 4 only.
            return with_mutated_index(op, index);
        }

        int64_t aligned_offset = 0;
        const bool is_aligned = alignment_analyzer.is_aligned(op, &aligned_offset);
        // We know the alignment_analyzer has been able to reason about
        // alignment if it reported the load as aligned, or reported a
        // nonzero offset from an aligned address.
        const bool known_alignment = is_aligned || aligned_offset != 0;

        const int lanes = ramp->lanes;
        const int native_lanes = required_alignment / op->type.bytes();
        const int stride = static_cast<int>(*const_stride);

        if (stride != 1) {
            internal_assert(stride >= 0);

            // If we know the offset of this strided load is smaller
            // than the stride, we can just make the load aligned now
            // without requiring more vectors from the dense
            // load. This makes loads like f(2*x + 1) into an aligned
            // load of double length, with a single shuffle.
            const int shift = (known_alignment && aligned_offset < stride) ? static_cast<int>(aligned_offset) : 0;

            // Load a dense vector covering all of the addresses in the load.
            Expr dense_base = simplify(ramp->base - shift);
            ModulusRemainder alignment = op->alignment - shift;
            Expr dense_index = Ramp::make(dense_base, 1, lanes * stride);
            Expr dense = make_load(op, dense_index, alignment);

            // Shuffle the dense load.
            return Shuffle::make_slice(dense, shift, stride, lanes);
        }

        // We now have a dense vector load to deal with.
        internal_assert(stride == 1);

        if (lanes < native_lanes) {
            // This load is smaller than a native vector. Load a
            // native vector.
            Expr ramp_base = ramp->base;
            ModulusRemainder alignment = op->alignment;
            int slice_offset = 0;

            // If the load is smaller than a native vector, fully fits inside
            // of it, and its offset is known, we can simply offset the native
            // load and slice.
            if (!is_aligned && aligned_offset != 0 && Int(32).can_represent(aligned_offset) &&
                (aligned_offset + lanes <= native_lanes)) {
                const int offset = static_cast<int>(aligned_offset);
                ramp_base = simplify(ramp_base - offset);
                alignment = alignment - offset;
                slice_offset = offset;
            }
            Expr native_load = make_load(op, Ramp::make(ramp_base, 1, native_lanes), alignment);

            // Slice the native load.
            return Shuffle::make_slice(native_load, slice_offset, 1, lanes);
        }

        if (lanes > native_lanes) {
            // This load is larger than a native vector. Load native
            // vectors, and concatenate the results.
            vector<Expr> slices;
            for (int i = 0; i < lanes; i += native_lanes) {
                const int slice_lanes = std::min(native_lanes, lanes - i);
                Expr slice_base = simplify(ramp->base + i);
                ModulusRemainder alignment = op->alignment + i;
                slices.push_back(make_load(op, Ramp::make(slice_base, 1, slice_lanes), alignment));
            }
            return Shuffle::make_concat(slices);
        }

        if (!is_aligned && aligned_offset != 0 && Int(32).can_represent(aligned_offset)) {
            // We know the offset of this load from an aligned
            // address. Rewrite this as an aligned load of two
            // native vectors, followed by a shuffle.
            const int offset = static_cast<int>(aligned_offset);
            Expr aligned_base = simplify(ramp->base - offset);
            ModulusRemainder alignment = op->alignment - offset;
            Expr aligned_load = make_load(op, Ramp::make(aligned_base, 1, lanes * 2), alignment);
            return Shuffle::make_slice(aligned_load, offset, 1, lanes);
        }

        // Already aligned, or alignment unknown: keep the load as-is, using
        // the index we have already mutated.
        return with_mutated_index(op, index);
    }
};
} // namespace
// Apply the AlignLoads rewrite to every eligible vector load in `s`.
// `alignment` is the required alignment in bytes; loads of at most
// `min_bytes_to_align` bytes are left alone.
Stmt align_loads(const Stmt &s, int alignment, int min_bytes_to_align) {
    AlignLoads aligner(alignment, min_bytes_to_align);
    Stmt result = aligner(s);
    return result;
}
} // namespace Internal
} // namespace Halide