Skip to content

Commit e954baf

Browse files
Krzysztof Rymskicopybara-github
authored andcommitted
Remove UB when used with not aligned data
Use uint8_t SIMD loads followed by type-level BitCast to satisfy C++ alignment constraints for UBSan. Fix tile attention test PiperOrigin-RevId: 921518179
1 parent e58e56c commit e954baf

1 file changed

Lines changed: 52 additions & 28 deletions

File tree

compression/compress-inl.h

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,26 @@ namespace gcpp {
5656
namespace HWY_NAMESPACE {
5757
namespace hn = hwy::HWY_NAMESPACE;
5858

59+
template <class D, typename Packed>
60+
static HWY_INLINE hn::Vec<D> LoadNonElementAligned(
61+
D d, const Packed* HWY_RESTRICT ptr, size_t offset_in_packed) {
62+
const hn::Repartition<uint8_t, D> du8;
63+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(ptr);
64+
return hn::BitCast(
65+
d, hn::LoadU(du8, src_bytes + offset_in_packed * sizeof(Packed)));
66+
}
67+
68+
template <class D, typename Packed>
69+
static HWY_INLINE hn::Vec<D> LoadNNonElementAligned(
70+
D d, const Packed* HWY_RESTRICT ptr, size_t offset_in_packed,
71+
size_t num_packed) {
72+
const hn::Repartition<uint8_t, D> du8;
73+
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(ptr);
74+
return hn::BitCast(
75+
d, hn::LoadN(du8, src_bytes + offset_in_packed * sizeof(Packed),
76+
num_packed * sizeof(Packed)));
77+
}
78+
5979
// Enables generic code independent of compression type.
6080
template <typename T> // primary, must specialize
6181
struct CompressTraits {};
@@ -92,10 +112,10 @@ struct CompressTraits<float> {
92112
const hn::Repartition<float, decltype(dbf16)> df;
93113
using VF = hn::Vec<decltype(df)>;
94114
const size_t NF = hn::Lanes(df);
95-
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
96-
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
97-
const VF f2 = hn::LoadU(df, packed.ptr + packed_ofs + 2 * NF);
98-
const VF f3 = hn::LoadU(df, packed.ptr + packed_ofs + 3 * NF);
115+
const VF f0 = LoadNonElementAligned(df, packed.ptr, packed_ofs + 0 * NF);
116+
const VF f1 = LoadNonElementAligned(df, packed.ptr, packed_ofs + 1 * NF);
117+
const VF f2 = LoadNonElementAligned(df, packed.ptr, packed_ofs + 2 * NF);
118+
const VF f3 = LoadNonElementAligned(df, packed.ptr, packed_ofs + 3 * NF);
99119
raw0 = hn::OrderedDemote2To(dbf16, f0, f1);
100120
raw1 = hn::OrderedDemote2To(dbf16, f2, f3);
101121
}
@@ -104,8 +124,8 @@ struct CompressTraits<float> {
104124
static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
105125
const size_t packed_ofs, VF& raw0, VF& raw1) {
106126
const size_t N = hn::Lanes(df);
107-
raw0 = hn::LoadU(df, packed.ptr + packed_ofs);
108-
raw1 = hn::LoadU(df, packed.ptr + packed_ofs + N);
127+
raw0 = LoadNonElementAligned(df, packed.ptr, packed_ofs);
128+
raw1 = LoadNonElementAligned(df, packed.ptr, packed_ofs + N);
109129
}
110130

111131
template <class DD, HWY_IF_F64_D(DD), class VD = hn::Vec<DD>>
@@ -114,9 +134,8 @@ struct CompressTraits<float> {
114134
const hn::Rebind<float, DD> df;
115135
using VF = hn::Vec<decltype(df)>;
116136
const size_t NF = hn::Lanes(df);
117-
// Two half loads are likely cheaper than one full + UpperHalf.
118-
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
119-
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
137+
const VF f0 = LoadNonElementAligned(df, packed.ptr, packed_ofs + 0 * NF);
138+
const VF f1 = LoadNonElementAligned(df, packed.ptr, packed_ofs + 1 * NF);
120139
raw0 = hn::PromoteTo(dd, f0);
121140
raw1 = hn::PromoteTo(dd, f1);
122141
}
@@ -132,17 +151,22 @@ struct CompressTraits<float> {
132151
size_t i = 0;
133152
if (num >= 2 * NF) {
134153
for (; i <= num - 2 * NF; i += 2 * NF) {
135-
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + i);
136-
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + i + NF);
154+
const VF f0 = LoadNonElementAligned(df, packed.ptr, packed_ofs + i);
155+
const VF f1 =
156+
LoadNonElementAligned(df, packed.ptr, packed_ofs + i + NF);
137157
hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
138158
}
139159
}
140160
const size_t remaining = num - i;
141161
HWY_DASSERT(remaining < 2 * NF);
142162
if (HWY_UNLIKELY(remaining != 0)) {
143-
const size_t remaining2 = remaining - HWY_MIN(remaining, NF);
144-
const VF f0 = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
145-
const VF f1 = hn::LoadN(df, packed.ptr + packed_ofs + i + NF, remaining2);
163+
const VF f0 =
164+
LoadNNonElementAligned(df, packed.ptr, packed_ofs + i, remaining);
165+
VF f1 = hn::Zero(df);
166+
if (remaining > NF) {
167+
f1 = LoadNNonElementAligned(df, packed.ptr, packed_ofs + i + NF,
168+
remaining - NF);
169+
}
146170
hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
147171
}
148172
}
@@ -157,14 +181,14 @@ struct CompressTraits<float> {
157181
size_t i = 0;
158182
if (num >= NF) {
159183
for (; i <= num - NF; i += NF) {
160-
const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
184+
const VF vf = LoadNonElementAligned(df, packed.ptr, packed_ofs + i);
161185
hn::StoreU(vf, df, raw + i);
162186
}
163187
}
164188
const size_t remaining = num - i;
165189
HWY_DASSERT(remaining < NF);
166190
if (HWY_UNLIKELY(remaining != 0)) {
167-
const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
191+
const VF vf = LoadNNonElementAligned(df, packed.ptr, packed_ofs + i, remaining);
168192
hn::StoreU(vf, df, raw + i); // adds zero padding
169193
}
170194
}
@@ -180,14 +204,14 @@ struct CompressTraits<float> {
180204
size_t i = 0;
181205
if (num >= ND) {
182206
for (; i <= num - ND; i += ND) {
183-
const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
207+
const VF vf = LoadNonElementAligned(df, packed.ptr, packed_ofs + i);
184208
hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i);
185209
}
186210
}
187211
const size_t remaining = num - i;
188212
HWY_DASSERT(remaining < ND);
189213
if (HWY_UNLIKELY(remaining != 0)) {
190-
const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
214+
const VF vf = LoadNNonElementAligned(df, packed.ptr, packed_ofs + i, remaining);
191215
hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i); // adds zero padding
192216
}
193217
}
@@ -231,8 +255,10 @@ struct CompressTraits<BF16> {
231255
HWY_DASSERT(remaining < 2 * NF);
232256
if (remaining != 0) {
233257
const VF raw0 = hn::LoadN(df, raw + i, remaining);
234-
const size_t remaining1 = remaining - HWY_MIN(remaining, NF);
235-
const VF raw1 = hn::LoadN(df, raw + i + NF, remaining1);
258+
VF raw1 = hn::Zero(df);
259+
if (remaining > NF) {
260+
raw1 = hn::LoadN(df, raw + i + NF, remaining - NF);
261+
}
236262

237263
hn::StoreN(hn::OrderedDemote2To(dbf, raw0, raw1), dbf,
238264
packed.ptr + packed_ofs + i, remaining);
@@ -266,8 +292,8 @@ struct CompressTraits<BF16> {
266292
const size_t packed_ofs, hn::Vec<DBF16>& raw0,
267293
hn::Vec<DBF16>& raw1) {
268294
const size_t N16 = hn::Lanes(dbf16);
269-
raw0 = hn::LoadU(dbf16, packed.ptr + packed_ofs);
270-
raw1 = hn::LoadU(dbf16, packed.ptr + packed_ofs + N16);
295+
raw0 = LoadNonElementAligned(dbf16, packed.ptr, packed_ofs);
296+
raw1 = LoadNonElementAligned(dbf16, packed.ptr, packed_ofs + N16);
271297
}
272298

273299
template <class DF, HWY_IF_F32_D(DF)>
@@ -276,7 +302,7 @@ struct CompressTraits<BF16> {
276302
hn::Vec<DF>& raw1) {
277303
const hn::Repartition<BF16, decltype(df)> dbf;
278304
using VBF = hn::Vec<decltype(dbf)>;
279-
const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs);
305+
const VBF packed0 = LoadNonElementAligned(dbf, packed.ptr, packed_ofs);
280306
raw0 = hn::PromoteLowerTo(df, packed0);
281307
raw1 = hn::PromoteUpperTo(df, packed0);
282308
}
@@ -291,16 +317,15 @@ struct CompressTraits<BF16> {
291317
size_t i = 0;
292318
if (num >= N16) {
293319
for (; i <= num - N16; i += N16) {
294-
const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs + i);
320+
const VBF packed0 = LoadNonElementAligned(dbf, packed.ptr, packed_ofs + i);
295321
hn::StoreU(packed0, dbf, raw + i);
296322
}
297323
}
298324

299325
const size_t remaining = num - i;
300326
HWY_DASSERT(remaining < N16);
301327
if (HWY_UNLIKELY(remaining != 0)) {
302-
const VBF packed0 =
303-
hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining);
328+
const VBF packed0 = LoadNNonElementAligned(dbf, packed.ptr, packed_ofs + i, remaining);
304329
hn::StoreU(packed0, dbf, raw + i);
305330
}
306331
}
@@ -363,8 +388,7 @@ struct CompressTraits<BF16> {
363388
const size_t remaining = num - i;
364389
HWY_DASSERT(remaining < 2 * NF);
365390
if (HWY_UNLIKELY(remaining != 0)) {
366-
const VBF packed0 =
367-
hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining);
391+
const VBF packed0 = LoadNNonElementAligned(dbf, packed.ptr, packed_ofs + i, remaining);
368392
const VF raw0 = hn::PromoteLowerTo(df, packed0);
369393
const VF raw1 = hn::PromoteUpperTo(df, packed0);
370394
// If at most one vector, the first store adds zero padding. Check before

0 commit comments

Comments
 (0)