Skip to content

Commit 307548c

Browse files
Krzysztof Rymskicopybara-github
authored andcommitted
Remove UB when used with not aligned data
Use uint8_t SIMD loads followed by type-level BitCast to satisfy C++ alignment constraints for UBSan. PiperOrigin-RevId: 921518179
1 parent 53bcd7a commit 307548c

1 file changed

Lines changed: 72 additions & 25 deletions

File tree

compression/compress-inl.h

Lines changed: 72 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,16 @@ struct CompressTraits<float> {
9292
const hn::Repartition<float, decltype(dbf16)> df;
9393
using VF = hn::Vec<decltype(df)>;
9494
const size_t NF = hn::Lanes(df);
95-
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
96-
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
97-
const VF f2 = hn::LoadU(df, packed.ptr + packed_ofs + 2 * NF);
98-
const VF f3 = hn::LoadU(df, packed.ptr + packed_ofs + 3 * NF);
95+
const hn::Repartition<uint8_t, decltype(df)> du8;
96+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
97+
const VF f0 = hn::BitCast(
98+
df, hn::LoadU(du8, src_bytes + (packed_ofs + 0 * NF) * sizeof(Packed)));
99+
const VF f1 = hn::BitCast(
100+
df, hn::LoadU(du8, src_bytes + (packed_ofs + 1 * NF) * sizeof(Packed)));
101+
const VF f2 = hn::BitCast(
102+
df, hn::LoadU(du8, src_bytes + (packed_ofs + 2 * NF) * sizeof(Packed)));
103+
const VF f3 = hn::BitCast(
104+
df, hn::LoadU(du8, src_bytes + (packed_ofs + 3 * NF) * sizeof(Packed)));
99105
raw0 = hn::OrderedDemote2To(dbf16, f0, f1);
100106
raw1 = hn::OrderedDemote2To(dbf16, f2, f3);
101107
}
@@ -104,8 +110,12 @@ struct CompressTraits<float> {
104110
static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
105111
const size_t packed_ofs, VF& raw0, VF& raw1) {
106112
const size_t N = hn::Lanes(df);
107-
raw0 = hn::LoadU(df, packed.ptr + packed_ofs);
108-
raw1 = hn::LoadU(df, packed.ptr + packed_ofs + N);
113+
const hn::Repartition<uint8_t, DF> du8;
114+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
115+
raw0 = hn::BitCast(df,
116+
hn::LoadU(du8, src_bytes + packed_ofs * sizeof(Packed)));
117+
raw1 = hn::BitCast(
118+
df, hn::LoadU(du8, src_bytes + (packed_ofs + N) * sizeof(Packed)));
109119
}
110120

111121
template <class DD, HWY_IF_F64_D(DD), class VD = hn::Vec<DD>>
@@ -114,9 +124,12 @@ struct CompressTraits<float> {
114124
const hn::Rebind<float, DD> df;
115125
using VF = hn::Vec<decltype(df)>;
116126
const size_t NF = hn::Lanes(df);
117-
// Two half loads are likely cheaper than one full + UpperHalf.
118-
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
119-
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
127+
const hn::Repartition<uint8_t, decltype(df)> du8;
128+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
129+
const VF f0 = hn::BitCast(
130+
df, hn::LoadU(du8, src_bytes + (packed_ofs + 0 * NF) * sizeof(Packed)));
131+
const VF f1 = hn::BitCast(
132+
df, hn::LoadU(du8, src_bytes + (packed_ofs + 1 * NF) * sizeof(Packed)));
120133
raw0 = hn::PromoteTo(dd, f0);
121134
raw1 = hn::PromoteTo(dd, f1);
122135
}
@@ -128,21 +141,31 @@ struct CompressTraits<float> {
128141
const hn::Repartition<float, decltype(dbf)> df;
129142
using VF = hn::Vec<decltype(df)>;
130143
const size_t NF = hn::Lanes(df);
144+
const hn::Repartition<uint8_t, decltype(dbf)> du8;
145+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
131146

132147
size_t i = 0;
133148
if (num >= 2 * NF) {
134149
for (; i <= num - 2 * NF; i += 2 * NF) {
135-
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + i);
136-
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + i + NF);
150+
const VF f0 = hn::BitCast(
151+
df, hn::LoadU(du8, src_bytes + (packed_ofs + i) * sizeof(Packed)));
152+
const VF f1 = hn::BitCast(
153+
df,
154+
hn::LoadU(du8, src_bytes + (packed_ofs + i + NF) * sizeof(Packed)));
137155
hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
138156
}
139157
}
140158
const size_t remaining = num - i;
141159
HWY_DASSERT(remaining < 2 * NF);
142160
if (HWY_UNLIKELY(remaining != 0)) {
143161
const size_t remaining2 = remaining - HWY_MIN(remaining, NF);
144-
const VF f0 = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
145-
const VF f1 = hn::LoadN(df, packed.ptr + packed_ofs + i + NF, remaining2);
162+
const VF f0 = hn::BitCast(
163+
df, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
164+
HWY_MIN(remaining, NF) * sizeof(Packed)));
165+
const VF f1 = hn::BitCast(
166+
df, hn::LoadN(du8,
167+
src_bytes + (packed_ofs + i + NF) * sizeof(Packed),
168+
remaining2 * sizeof(Packed)));
146169
hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
147170
}
148171
}
@@ -153,18 +176,23 @@ struct CompressTraits<float> {
153176
float* HWY_RESTRICT raw, size_t num) {
154177
using VF = hn::Vec<decltype(df)>;
155178
const size_t NF = hn::Lanes(df);
179+
const hn::Repartition<uint8_t, DF> du8;
180+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
156181

157182
size_t i = 0;
158183
if (num >= NF) {
159184
for (; i <= num - NF; i += NF) {
160-
const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
185+
const VF vf = hn::BitCast(
186+
df, hn::LoadU(du8, src_bytes + (packed_ofs + i) * sizeof(Packed)));
161187
hn::StoreU(vf, df, raw + i);
162188
}
163189
}
164190
const size_t remaining = num - i;
165191
HWY_DASSERT(remaining < NF);
166192
if (HWY_UNLIKELY(remaining != 0)) {
167-
const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
193+
const VF vf = hn::BitCast(
194+
df, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
195+
remaining * sizeof(Packed)));
168196
hn::StoreU(vf, df, raw + i); // adds zero padding
169197
}
170198
}
@@ -176,18 +204,23 @@ struct CompressTraits<float> {
176204
const hn::Rebind<float, DD> df;
177205
using VF = hn::Vec<decltype(df)>;
178206
const size_t ND = hn::Lanes(dd);
207+
const hn::Repartition<uint8_t, decltype(df)> du8;
208+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
179209

180210
size_t i = 0;
181211
if (num >= ND) {
182212
for (; i <= num - ND; i += ND) {
183-
const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
213+
const VF vf = hn::BitCast(
214+
df, hn::LoadU(du8, src_bytes + (packed_ofs + i) * sizeof(Packed)));
184215
hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i);
185216
}
186217
}
187218
const size_t remaining = num - i;
188219
HWY_DASSERT(remaining < ND);
189220
if (HWY_UNLIKELY(remaining != 0)) {
190-
const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
221+
const VF vf = hn::BitCast(
222+
df, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
223+
remaining * sizeof(Packed)));
191224
hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i); // adds zero padding
192225
}
193226
}
@@ -265,9 +298,13 @@ struct CompressTraits<BF16> {
265298
const PackedSpan<const Packed>& packed,
266299
const size_t packed_ofs, hn::Vec<DBF16>& raw0,
267300
hn::Vec<DBF16>& raw1) {
301+
const hn::Repartition<uint8_t, DBF16> du8;
268302
const size_t N16 = hn::Lanes(dbf16);
269-
raw0 = hn::LoadU(dbf16, packed.ptr + packed_ofs);
270-
raw1 = hn::LoadU(dbf16, packed.ptr + packed_ofs + N16);
303+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
304+
raw0 = hn::BitCast(dbf16,
305+
hn::LoadU(du8, src_bytes + packed_ofs * sizeof(Packed)));
306+
raw1 = hn::BitCast(
307+
dbf16, hn::LoadU(du8, src_bytes + (packed_ofs + N16) * sizeof(Packed)));
271308
}
272309

273310
template <class DF, HWY_IF_F32_D(DF)>
@@ -276,7 +313,10 @@ struct CompressTraits<BF16> {
276313
hn::Vec<DF>& raw1) {
277314
const hn::Repartition<BF16, decltype(df)> dbf;
278315
using VBF = hn::Vec<decltype(dbf)>;
279-
const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs);
316+
const hn::Repartition<uint8_t, decltype(df)> du8;
317+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
318+
const VBF packed0 = hn::BitCast(
319+
dbf, hn::LoadU(du8, src_bytes + packed_ofs * sizeof(Packed)));
280320
raw0 = hn::PromoteLowerTo(df, packed0);
281321
raw1 = hn::PromoteUpperTo(df, packed0);
282322
}
@@ -287,20 +327,24 @@ struct CompressTraits<BF16> {
287327
BF16* HWY_RESTRICT raw, size_t num) {
288328
using VBF = hn::Vec<decltype(dbf)>;
289329
const size_t N16 = hn::Lanes(dbf);
330+
const hn::Repartition<uint8_t, DBF> du8;
331+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
290332

291333
size_t i = 0;
292334
if (num >= N16) {
293335
for (; i <= num - N16; i += N16) {
294-
const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs + i);
336+
const VBF packed0 = hn::BitCast(
337+
dbf, hn::LoadU(du8, src_bytes + (packed_ofs + i) * sizeof(Packed)));
295338
hn::StoreU(packed0, dbf, raw + i);
296339
}
297340
}
298341

299342
const size_t remaining = num - i;
300343
HWY_DASSERT(remaining < N16);
301344
if (HWY_UNLIKELY(remaining != 0)) {
302-
const VBF packed0 =
303-
hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining);
345+
const VBF packed0 = hn::BitCast(
346+
dbf, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
347+
remaining * sizeof(Packed)));
304348
hn::StoreU(packed0, dbf, raw + i);
305349
}
306350
}
@@ -363,8 +407,11 @@ struct CompressTraits<BF16> {
363407
const size_t remaining = num - i;
364408
HWY_DASSERT(remaining < 2 * NF);
365409
if (HWY_UNLIKELY(remaining != 0)) {
366-
const VBF packed0 =
367-
hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining);
410+
const hn::Repartition<uint8_t, decltype(dbf)> du8;
411+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
412+
const VBF packed0 = hn::BitCast(
413+
dbf, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
414+
remaining * sizeof(Packed)));
368415
const VF raw0 = hn::PromoteLowerTo(df, packed0);
369416
const VF raw1 = hn::PromoteUpperTo(df, packed0);
370417
// If at most one vector, the first store adds zero padding. Check before

0 commit comments

Comments
 (0)