@@ -92,10 +92,16 @@ struct CompressTraits<float> {
9292 const hn::Repartition<float , decltype (dbf16)> df;
9393 using VF = hn::Vec<decltype (df)>;
9494 const size_t NF = hn::Lanes (df);
95- const VF f0 = hn::LoadU (df, packed.ptr + packed_ofs + 0 * NF);
96- const VF f1 = hn::LoadU (df, packed.ptr + packed_ofs + 1 * NF);
97- const VF f2 = hn::LoadU (df, packed.ptr + packed_ofs + 2 * NF);
98- const VF f3 = hn::LoadU (df, packed.ptr + packed_ofs + 3 * NF);
95+ const hn::Repartition<uint8_t , decltype (df)> du8;
96+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
97+ const VF f0 = hn::BitCast (
98+ df, hn::LoadU (du8, src_bytes + (packed_ofs + 0 * NF) * sizeof (Packed)));
99+ const VF f1 = hn::BitCast (
100+ df, hn::LoadU (du8, src_bytes + (packed_ofs + 1 * NF) * sizeof (Packed)));
101+ const VF f2 = hn::BitCast (
102+ df, hn::LoadU (du8, src_bytes + (packed_ofs + 2 * NF) * sizeof (Packed)));
103+ const VF f3 = hn::BitCast (
104+ df, hn::LoadU (du8, src_bytes + (packed_ofs + 3 * NF) * sizeof (Packed)));
99105 raw0 = hn::OrderedDemote2To (dbf16, f0, f1);
100106 raw1 = hn::OrderedDemote2To (dbf16, f2, f3);
101107 }
@@ -104,8 +110,12 @@ struct CompressTraits<float> {
104110 static HWY_INLINE void Load2 (DF df, const PackedSpan<const Packed>& packed,
105111 const size_t packed_ofs, VF& raw0, VF& raw1) {
106112 const size_t N = hn::Lanes (df);
107- raw0 = hn::LoadU (df, packed.ptr + packed_ofs);
108- raw1 = hn::LoadU (df, packed.ptr + packed_ofs + N);
113+ const hn::Repartition<uint8_t , DF> du8;
114+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
115+ raw0 = hn::BitCast (df,
116+ hn::LoadU (du8, src_bytes + packed_ofs * sizeof (Packed)));
117+ raw1 = hn::BitCast (
118+ df, hn::LoadU (du8, src_bytes + (packed_ofs + N) * sizeof (Packed)));
109119 }
110120
111121 template <class DD , HWY_IF_F64_D(DD), class VD = hn::Vec<DD>>
@@ -114,9 +124,12 @@ struct CompressTraits<float> {
114124 const hn::Rebind<float , DD> df;
115125 using VF = hn::Vec<decltype (df)>;
116126 const size_t NF = hn::Lanes (df);
117- // Two half loads are likely cheaper than one full + UpperHalf.
118- const VF f0 = hn::LoadU (df, packed.ptr + packed_ofs + 0 * NF);
119- const VF f1 = hn::LoadU (df, packed.ptr + packed_ofs + 1 * NF);
127+ const hn::Repartition<uint8_t , decltype (df)> du8;
128+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
129+ const VF f0 = hn::BitCast (
130+ df, hn::LoadU (du8, src_bytes + (packed_ofs + 0 * NF) * sizeof (Packed)));
131+ const VF f1 = hn::BitCast (
132+ df, hn::LoadU (du8, src_bytes + (packed_ofs + 1 * NF) * sizeof (Packed)));
120133 raw0 = hn::PromoteTo (dd, f0);
121134 raw1 = hn::PromoteTo (dd, f1);
122135 }
@@ -128,21 +141,31 @@ struct CompressTraits<float> {
128141 const hn::Repartition<float , decltype (dbf)> df;
129142 using VF = hn::Vec<decltype (df)>;
130143 const size_t NF = hn::Lanes (df);
144+ const hn::Repartition<uint8_t , decltype (dbf)> du8;
145+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
131146
132147 size_t i = 0 ;
133148 if (num >= 2 * NF) {
134149 for (; i <= num - 2 * NF; i += 2 * NF) {
135- const VF f0 = hn::LoadU (df, packed.ptr + packed_ofs + i);
136- const VF f1 = hn::LoadU (df, packed.ptr + packed_ofs + i + NF);
150+ const VF f0 = hn::BitCast (
151+ df, hn::LoadU (du8, src_bytes + (packed_ofs + i) * sizeof (Packed)));
152+ const VF f1 = hn::BitCast (
153+ df,
154+ hn::LoadU (du8, src_bytes + (packed_ofs + i + NF) * sizeof (Packed)));
137155 hn::StoreU (hn::OrderedDemote2To (dbf, f0, f1), dbf, raw + i);
138156 }
139157 }
140158 const size_t remaining = num - i;
141159 HWY_DASSERT (remaining < 2 * NF);
142160 if (HWY_UNLIKELY (remaining != 0 )) {
143161 const size_t remaining2 = remaining - HWY_MIN (remaining, NF);
144- const VF f0 = hn::LoadN (df, packed.ptr + packed_ofs + i, remaining);
145- const VF f1 = hn::LoadN (df, packed.ptr + packed_ofs + i + NF, remaining2);
162+ const VF f0 = hn::BitCast (
163+ df, hn::LoadN (du8, src_bytes + (packed_ofs + i) * sizeof (Packed),
164+ HWY_MIN (remaining, NF) * sizeof (Packed)));
165+ const VF f1 = hn::BitCast (
166+ df, hn::LoadN (du8,
167+ src_bytes + (packed_ofs + i + NF) * sizeof (Packed),
168+ remaining2 * sizeof (Packed)));
146169 hn::StoreU (hn::OrderedDemote2To (dbf, f0, f1), dbf, raw + i);
147170 }
148171 }
@@ -153,18 +176,23 @@ struct CompressTraits<float> {
153176 float * HWY_RESTRICT raw, size_t num) {
154177 using VF = hn::Vec<decltype (df)>;
155178 const size_t NF = hn::Lanes (df);
179+ const hn::Repartition<uint8_t , DF> du8;
180+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
156181
157182 size_t i = 0 ;
158183 if (num >= NF) {
159184 for (; i <= num - NF; i += NF) {
160- const VF vf = hn::LoadU (df, packed.ptr + packed_ofs + i);
185+ const VF vf = hn::BitCast (
186+ df, hn::LoadU (du8, src_bytes + (packed_ofs + i) * sizeof (Packed)));
161187 hn::StoreU (vf, df, raw + i);
162188 }
163189 }
164190 const size_t remaining = num - i;
165191 HWY_DASSERT (remaining < NF);
166192 if (HWY_UNLIKELY (remaining != 0 )) {
167- const VF vf = hn::LoadN (df, packed.ptr + packed_ofs + i, remaining);
193+ const VF vf = hn::BitCast (
194+ df, hn::LoadN (du8, src_bytes + (packed_ofs + i) * sizeof (Packed),
195+ remaining * sizeof (Packed)));
168196 hn::StoreU (vf, df, raw + i); // adds zero padding
169197 }
170198 }
@@ -176,18 +204,23 @@ struct CompressTraits<float> {
176204 const hn::Rebind<float , DD> df;
177205 using VF = hn::Vec<decltype (df)>;
178206 const size_t ND = hn::Lanes (dd);
207+ const hn::Repartition<uint8_t , decltype (df)> du8;
208+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
179209
180210 size_t i = 0 ;
181211 if (num >= ND) {
182212 for (; i <= num - ND; i += ND) {
183- const VF vf = hn::LoadU (df, packed.ptr + packed_ofs + i);
213+ const VF vf = hn::BitCast (
214+ df, hn::LoadU (du8, src_bytes + (packed_ofs + i) * sizeof (Packed)));
184215 hn::StoreU (hn::PromoteTo (dd, vf), dd, raw + i);
185216 }
186217 }
187218 const size_t remaining = num - i;
188219 HWY_DASSERT (remaining < ND);
189220 if (HWY_UNLIKELY (remaining != 0 )) {
190- const VF vf = hn::LoadN (df, packed.ptr + packed_ofs + i, remaining);
221+ const VF vf = hn::BitCast (
222+ df, hn::LoadN (du8, src_bytes + (packed_ofs + i) * sizeof (Packed),
223+ remaining * sizeof (Packed)));
191224 hn::StoreU (hn::PromoteTo (dd, vf), dd, raw + i); // adds zero padding
192225 }
193226 }
@@ -265,9 +298,13 @@ struct CompressTraits<BF16> {
265298 const PackedSpan<const Packed>& packed,
266299 const size_t packed_ofs, hn::Vec<DBF16>& raw0,
267300 hn::Vec<DBF16>& raw1) {
301+ const hn::Repartition<uint8_t , DBF16> du8;
268302 const size_t N16 = hn::Lanes (dbf16);
269- raw0 = hn::LoadU (dbf16, packed.ptr + packed_ofs);
270- raw1 = hn::LoadU (dbf16, packed.ptr + packed_ofs + N16);
303+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
304+ raw0 = hn::BitCast (dbf16,
305+ hn::LoadU (du8, src_bytes + packed_ofs * sizeof (Packed)));
306+ raw1 = hn::BitCast (
307+ dbf16, hn::LoadU (du8, src_bytes + (packed_ofs + N16) * sizeof (Packed)));
271308 }
272309
273310 template <class DF , HWY_IF_F32_D(DF)>
@@ -276,7 +313,10 @@ struct CompressTraits<BF16> {
276313 hn::Vec<DF>& raw1) {
277314 const hn::Repartition<BF16, decltype (df)> dbf;
278315 using VBF = hn::Vec<decltype (dbf)>;
279- const VBF packed0 = hn::LoadU (dbf, packed.ptr + packed_ofs);
316+ const hn::Repartition<uint8_t , decltype (df)> du8;
317+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
318+ const VBF packed0 = hn::BitCast (
319+ dbf, hn::LoadU (du8, src_bytes + packed_ofs * sizeof (Packed)));
280320 raw0 = hn::PromoteLowerTo (df, packed0);
281321 raw1 = hn::PromoteUpperTo (df, packed0);
282322 }
@@ -287,20 +327,24 @@ struct CompressTraits<BF16> {
287327 BF16* HWY_RESTRICT raw, size_t num) {
288328 using VBF = hn::Vec<decltype (dbf)>;
289329 const size_t N16 = hn::Lanes (dbf);
330+ const hn::Repartition<uint8_t , DBF> du8;
331+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
290332
291333 size_t i = 0 ;
292334 if (num >= N16) {
293335 for (; i <= num - N16; i += N16) {
294- const VBF packed0 = hn::LoadU (dbf, packed.ptr + packed_ofs + i);
336+ const VBF packed0 = hn::BitCast (
337+ dbf, hn::LoadU (du8, src_bytes + (packed_ofs + i) * sizeof (Packed)));
295338 hn::StoreU (packed0, dbf, raw + i);
296339 }
297340 }
298341
299342 const size_t remaining = num - i;
300343 HWY_DASSERT (remaining < N16);
301344 if (HWY_UNLIKELY (remaining != 0 )) {
302- const VBF packed0 =
303- hn::LoadN (dbf, packed.ptr + packed_ofs + i, remaining);
345+ const VBF packed0 = hn::BitCast (
346+ dbf, hn::LoadN (du8, src_bytes + (packed_ofs + i) * sizeof (Packed),
347+ remaining * sizeof (Packed)));
304348 hn::StoreU (packed0, dbf, raw + i);
305349 }
306350 }
@@ -363,8 +407,11 @@ struct CompressTraits<BF16> {
363407 const size_t remaining = num - i;
364408 HWY_DASSERT (remaining < 2 * NF);
365409 if (HWY_UNLIKELY (remaining != 0 )) {
366- const VBF packed0 =
367- hn::LoadN (dbf, packed.ptr + packed_ofs + i, remaining);
410+ const hn::Repartition<uint8_t , decltype (dbf)> du8;
411+ const uint8_t * src_bytes = reinterpret_cast <const uint8_t *>(packed.ptr );
412+ const VBF packed0 = hn::BitCast (
413+ dbf, hn::LoadN (du8, src_bytes + (packed_ofs + i) * sizeof (Packed),
414+ remaining * sizeof (Packed)));
368415 const VF raw0 = hn::PromoteLowerTo (df, packed0);
369416 const VF raw1 = hn::PromoteUpperTo (df, packed0);
370417 // If at most one vector, the first store adds zero padding. Check before
0 commit comments