
Commit 42e74b7

feat: add runtime batch_bool mask overloads for load_masked/store_masked
Add runtime-mask overloads of xsimd::load_masked and xsimd::store_masked across AVX, AVX2, AVX-512, SSE, SVE, RVV, and NEON. The generic common-path fallback is collapsed to a whole-vector select, and the unaligned page-cross fast path is dropped, since the underlying intrinsics suppress faults on masked-off lanes regardless of alignment.
1 parent 5141ff0 commit 42e74b7

11 files changed: 462 additions & 11 deletions


docs/source/api/data_transfer.rst

Lines changed: 15 additions & 2 deletions
@@ -12,7 +12,7 @@ Data Transfers
 From memory:

 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`load`                      | load values from memory (optionally masked)        |
+| :cpp:func:`load`                      | load values from memory (optionally masked) [#m]_  |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`load_aligned`              | load values from aligned memory                    |
 +---------------------------------------+----------------------------------------------------+
@@ -32,7 +32,7 @@ From a scalar:
 To memory:

 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`store`                     | store values to memory (optionally masked)         |
+| :cpp:func:`store`                     | store values to memory (optionally masked) [#m]_   |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`store_aligned`             | store values to aligned memory                     |
 +---------------------------------------+----------------------------------------------------+
@@ -84,3 +84,16 @@ The following empty types are used for tag dispatching:

 .. doxygenstruct:: xsimd::unaligned_mode
    :project: xsimd
+
+.. rubric:: Footnotes
+
+.. [#m] Masked ``load`` / ``store`` come in two flavours. The
+   :cpp:class:`batch_bool_constant` overload encodes the mask in the type, is
+   resolved at compile time and is always efficient. The runtime
+   :cpp:class:`batch_bool` overload, by contrast, falls back to a per-lane
+   scalar loop on architectures without a native masked load/store
+   instruction — SSE2 through SSE4.2, NEON/NEON64, VSX, S390x, and WASM.
+   AVX, AVX2, AVX-512, SVE and RVV use native masked instructions and pay no
+   such penalty. Prefer the compile-time mask whenever the selection is known
+   at compile time, and avoid runtime-mask loads/stores in hot inner loops on
+   the affected architectures.
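Aside (not part of the commit): the contract this footnote describes can be pinned down in a few lines of scalar C++. Whatever path an architecture takes — native masked instruction, select-based lowering, or per-lane loop — a runtime-mask load must behave as if each active lane were read individually and each inactive lane were zeroed. A minimal reference sketch (the function name is illustrative, not an xsimd API):

#include <array>
#include <cstddef>

// Reference semantics of a runtime-mask load: active lanes read mem[i],
// inactive lanes yield T(0), and inactive lanes never touch memory.
template <class T, std::size_t N>
std::array<T, N> masked_load_ref(T const* mem, std::array<bool, N> const& mask)
{
    std::array<T, N> out {}; // zero-initialised: inactive lanes stay 0
    for (std::size_t i = 0; i < N; ++i)
        if (mask[i])
            out[i] = mem[i]; // only active lanes read memory
    return out;
}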

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 61 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include <algorithm>
 #include <array>
 #include <complex>
+#include <cstdint>

 #include "../../types/xsimd_batch_constant.hpp"
 #include "./xsimd_common_details.hpp"
@@ -374,6 +375,39 @@ namespace xsimd
             return batch<T_out, A>::load(buffer.data(), aligned_mode {});
         }

+        template <class A, class T>
+        XSIMD_INLINE batch<T, A>
+        load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, aligned_mode, requires_arch<common>) noexcept
+        {
+            // Aligned mode contract: ``mem`` is aligned to ``A::alignment()``,
+            // and ``A::alignment() >= sizeof(batch<T, A>)`` for every common-
+            // fallback arch (SSE2-SSE4.2, NEON, NEON64, VSX, S390x, WASM — all
+            // 16-byte aligned, 16-byte vectors). The whole vector therefore
+            // lives inside a single alignment unit (and a single page, since
+            // pages are >= alignment), so an unconditional load cannot fault
+            // on inactive lanes. Lower the masked load to ``select`` against a
+            // zero broadcast — collapses to ~3 SIMD ops on every fallback arch.
+            return select(mask,
+                          batch<T, A>::load_aligned(mem),
+                          batch<T, A>(T(0)));
+        }
+
+        template <class A, class T>
+        XSIMD_INLINE batch<T, A>
+        load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, unaligned_mode, requires_arch<common>) noexcept
+        {
+            // Unaligned + runtime mask: ``mem`` may straddle a page boundary
+            // whose neighbour is unmapped, so an unconditional whole-vector
+            // ``load_unaligned`` is unsafe. Stay scalar.
+            constexpr std::size_t size = batch<T, A>::size;
+            alignas(A::alignment()) std::array<T, size> buffer {};
+            const uint64_t bits = mask.mask();
+            for (std::size_t i = 0; i < size; ++i)
+                if ((bits >> i) & uint64_t(1))
+                    buffer[i] = mem[i];
+            return batch<T, A>::load_aligned(buffer.data());
+        }
+
         template <class A, class T_in, class T_out, bool... Values, class alignment>
         XSIMD_INLINE void
         store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...>, alignment, requires_arch<common>) noexcept
@@ -388,6 +422,33 @@ namespace xsimd
             }
         }

+        template <class A, class T>
+        XSIMD_INLINE void
+        store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, aligned_mode, requires_arch<common>) noexcept
+        {
+            // Symmetric to load_masked: aligned ``mem`` cannot fault for any
+            // lane in the batch, so a read-modify-write through ``select`` is
+            // safe and collapses to load + select + store on every fallback
+            // arch.
+            const auto current = batch<T, A>::load_aligned(mem);
+            select(mask, src, current).store_aligned(mem);
+        }
+
+        template <class A, class T>
+        XSIMD_INLINE void
+        store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, unaligned_mode, requires_arch<common>) noexcept
+        {
+            // Symmetric to the unaligned load: unaligned RMW could fault on a
+            // page boundary, so stay scalar.
+            constexpr std::size_t size = batch<T, A>::size;
+            alignas(A::alignment()) std::array<T, size> src_buf;
+            src.store_aligned(src_buf.data());
+            const uint64_t bits = mask.mask();
+            for (std::size_t i = 0; i < size; ++i)
+                if ((bits >> i) & uint64_t(1))
+                    mem[i] = src_buf[i];
+        }
+
         template <class A, bool... Values, class Mode>
         XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem, batch_bool_constant<int32_t, A, Values...>, convert<int32_t>, Mode, requires_arch<A>) noexcept
         {
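Aside (not part of the diff): to make the "collapses to ~3 SIMD ops" claim in the aligned path concrete, here is roughly what select-against-zero lowers to on bare SSE2, written with raw intrinsics outside of xsimd. The function name is illustrative; the intrinsics are standard SSE2.

#include <emmintrin.h> // SSE2
#include <cstdint>

// What the aligned common-path load_masked boils down to on SSE2: one full
// aligned load plus one AND against the lane mask (all-ones for active
// lanes). The 16-byte aligned load cannot cross a page boundary, so the
// inactive lanes are read but can never fault.
__m128i masked_load_aligned_i32(int32_t const* mem /* 16-byte aligned */, __m128i lane_mask)
{
    __m128i v = _mm_load_si128(reinterpret_cast<__m128i const*>(mem));
    return _mm_and_si128(v, lane_mask); // select(mask, v, 0)
}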

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 33 additions & 0 deletions
@@ -1015,6 +1015,23 @@ namespace xsimd
             }
         }

+        // Runtime-mask load for float/double on AVX. Both aligned_mode and
+        // unaligned_mode map to _mm256_maskload_* — the intrinsic does not fault
+        // on masked-off lanes, so partial loads across page boundaries are safe.
+        template <class A, class Mode>
+        XSIMD_INLINE batch<float, A>
+        load_masked(float const* mem, batch_bool<float, A> mask, convert<float>, Mode, requires_arch<avx>) noexcept
+        {
+            return _mm256_maskload_ps(mem, _mm256_castps_si256(mask));
+        }
+
+        template <class A, class Mode>
+        XSIMD_INLINE batch<double, A>
+        load_masked(double const* mem, batch_bool<double, A> mask, convert<double>, Mode, requires_arch<avx>) noexcept
+        {
+            return _mm256_maskload_pd(mem, _mm256_castpd_si256(mask));
+        }
+
         // store_masked
         namespace detail
         {
@@ -1031,6 +1048,22 @@ namespace xsimd
             }
         }

+        // Runtime-mask store for float/double on AVX. Same fault-suppression
+        // semantics as the masked loads above; alignment mode is irrelevant.
+        template <class A, class Mode>
+        XSIMD_INLINE void
+        store_masked(float* mem, batch<float, A> const& src, batch_bool<float, A> mask, Mode, requires_arch<avx>) noexcept
+        {
+            _mm256_maskstore_ps(mem, _mm256_castps_si256(mask), src);
+        }
+
+        template <class A, class Mode>
+        XSIMD_INLINE void
+        store_masked(double* mem, batch<double, A> const& src, batch_bool<double, A> mask, Mode, requires_arch<avx>) noexcept
+        {
+            _mm256_maskstore_pd(mem, _mm256_castpd_si256(mask), src);
+        }
+
         template <class A, class T, bool... Values, class Mode>
         XSIMD_INLINE void store_masked(T* mem, batch<T, A> const& src, batch_bool_constant<T, A, Values...> mask, Mode, requires_arch<avx>) noexcept
         {
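Aside (not part of the diff): a small standalone illustration of the fault-suppression property these overloads rely on. _mm256_maskload_ps reads only the lanes whose 32-bit mask element has its most significant bit set and zeroes the rest, so a partial vector at the edge of a mapped region is safe to load. The mask construction and function name below are illustrative, not xsimd code.

#include <immintrin.h>
#include <cstdint>

// Load the first n floats of a buffer (0 <= n <= 8) without reading past it.
// Inactive lanes are zeroed and perform no memory access.
__m256 load_first_n(float const* p, int n)
{
    alignas(32) int32_t m[8];
    for (int i = 0; i < 8; ++i)
        m[i] = (i < n) ? -1 : 0; // MSB set => lane active
    __m256i mask = _mm256_load_si256(reinterpret_cast<__m256i const*>(m));
    return _mm256_maskload_ps(p, mask);
}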

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 28 additions & 5 deletions
@@ -119,7 +119,6 @@ namespace xsimd
         }

         // load_masked
-        // AVX2 low-level helpers (operate on raw SIMD registers)
         namespace detail
         {
             XSIMD_INLINE __m256i maskload(const int32_t* mem, __m256i mask) noexcept
@@ -138,14 +137,12 @@ namespace xsimd
             }
         }

-        // single templated implementation for integer masked loads (32/64-bit)
         template <class A, class T, bool... Values, class Mode>
         XSIMD_INLINE std::enable_if_t<std::is_integral<T>::value && (sizeof(T) >= 4), batch<T, A>>
         load_masked(T const* mem, batch_bool_constant<T, A, Values...> mask, convert<T>, Mode, requires_arch<avx2>) noexcept
         {
             static_assert(sizeof(T) == 4 || sizeof(T) == 8, "load_masked supports only 32/64-bit integers on AVX2");
             using int_t = std::conditional_t<sizeof(T) == 4, int32_t, long long>;
-            // Use the raw register-level maskload helpers for the remaining cases.
             return detail::maskload(reinterpret_cast<const int_t*>(mem), mask.as_batch());
         }

@@ -175,6 +172,20 @@ namespace xsimd
             return bitwise_cast<uint64_t>(r);
         }

+        // Runtime-mask load for 32/64-bit integers on AVX2. 8/16-bit integers
+        // fall back to the scalar common path: AVX2 has no native maskload for
+        // those widths, and a load-then-blend would break fault-suppression at
+        // page boundaries (the main reason callers ask for a masked load).
+        // Both aligned_mode and unaligned_mode route to the same intrinsic —
+        // masked-off lanes do not fault regardless of alignment.
+        template <class A, class T, class Mode>
+        XSIMD_INLINE std::enable_if_t<std::is_integral<T>::value && (sizeof(T) == 4 || sizeof(T) == 8), batch<T, A>>
+        load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, Mode, requires_arch<avx2>) noexcept
+        {
+            using int_t = std::conditional_t<sizeof(T) == 4, int32_t, long long>;
+            return detail::maskload(reinterpret_cast<const int_t*>(mem), __m256i(mask));
+        }
+
         // store_masked
         namespace detail
         {
@@ -196,14 +207,12 @@ namespace xsimd
         {
             constexpr size_t lanes_per_half = batch<T, A>::size / 2;

-            // confined to lower 128-bit half → forward to SSE
             XSIMD_IF_CONSTEXPR(mask.countl_zero() >= lanes_per_half)
             {
                 constexpr auto mlo = ::xsimd::detail::lower_half<sse4_2>(mask);
                 const auto lo = detail::lower_half(src);
                 store_masked<sse4_2>(mem, lo, mlo, Mode {}, sse4_2 {});
             }
-            // confined to upper 128-bit half → forward to SSE
             else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= lanes_per_half)
             {
                 constexpr auto mhi = ::xsimd::detail::upper_half<sse4_2>(mask);
@@ -230,6 +239,20 @@ namespace xsimd
             store_masked<A>(reinterpret_cast<int64_t*>(mem), s64, batch_bool_constant<int64_t, A, Values...> {}, Mode {}, avx2 {});
         }

+        template <class A, class T, class Mode>
+        XSIMD_INLINE std::enable_if_t<std::is_integral<T>::value && (sizeof(T) == 4 || sizeof(T) == 8), void>
+        store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, Mode, requires_arch<avx2>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(sizeof(T) == 4)
+            {
+                _mm256_maskstore_epi32(reinterpret_cast<int*>(mem), __m256i(mask), __m256i(src));
+            }
+            else
+            {
+                _mm256_maskstore_epi64(reinterpret_cast<long long*>(mem), __m256i(mask), __m256i(src));
+            }
+        }
+
         // load_stream
         template <class A, class T, class = std::enable_if_t<std::is_integral<T>::value, void>>
         XSIMD_INLINE batch<T, A> load_stream(T const* mem, convert<T>, requires_arch<avx2>) noexcept
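Aside (not part of the diff): for symmetry with the AVX load sketch above, an illustrative (non-xsimd) tail store built on the same AVX2 integer maskstore used in this file. Lanes whose mask MSB is clear are never written, so storing a partial vector at the end of an array neither faults nor clobbers adjacent data.

#include <immintrin.h>
#include <cstdint>

// Store the first n int32 lanes of v to p (0 <= n <= 8); the remaining
// destination elements are left untouched.
void store_first_n(int32_t* p, __m256i v, int n)
{
    alignas(32) int32_t m[8];
    for (int i = 0; i < 8; ++i)
        m[i] = (i < n) ? -1 : 0; // MSB set => lane written
    __m256i mask = _mm256_load_si256(reinterpret_cast<__m256i const*>(m));
    _mm256_maskstore_epi32(reinterpret_cast<int*>(p), mask, v);
}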

include/xsimd/arch/xsimd_common_fwd.hpp

Lines changed: 8 additions & 0 deletions
@@ -79,8 +79,16 @@ namespace xsimd
         XSIMD_INLINE batch<T, A> load(T const* mem, unaligned_mode, requires_arch<A>) noexcept;
         template <class A, class T_in, class T_out, bool... Values, class alignment>
         XSIMD_INLINE batch<T_out, A> load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...> mask, convert<T_out>, alignment, requires_arch<common>) noexcept;
+        template <class A, class T>
+        XSIMD_INLINE batch<T, A> load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, aligned_mode, requires_arch<common>) noexcept;
+        template <class A, class T>
+        XSIMD_INLINE batch<T, A> load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, unaligned_mode, requires_arch<common>) noexcept;
         template <class A, class T_in, class T_out, bool... Values, class alignment>
         XSIMD_INLINE void store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...> mask, alignment, requires_arch<common>) noexcept;
+        template <class A, class T>
+        XSIMD_INLINE void store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, aligned_mode, requires_arch<common>) noexcept;
+        template <class A, class T>
+        XSIMD_INLINE void store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, unaligned_mode, requires_arch<common>) noexcept;
         template <class A, bool... Values, class Mode>
         XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem, batch_bool_constant<int32_t, A, Values...> mask, convert<int32_t>, Mode, requires_arch<A>) noexcept;
         template <class A, bool... Values, class Mode>

include/xsimd/arch/xsimd_rvv.hpp

Lines changed: 24 additions & 0 deletions
@@ -409,6 +409,11 @@ namespace xsimd
         {
             XSIMD_RVV_OVERLOAD(rvvle, (__riscv_vle XSIMD_RVV_S _v_ XSIMD_RVV_TSM), , vec(T const*))
             XSIMD_RVV_OVERLOAD(rvvse, (__riscv_vse XSIMD_RVV_S _v_ XSIMD_RVV_TSM), , void(T*, vec))
+            // Masked load (mask-undisturbed with zero passthrough): inactive lanes read as 0,
+            // no memory access is performed for inactive lanes (page-fault safe).
+            XSIMD_RVV_OVERLOAD(rvvle_mu, (__riscv_vle XSIMD_RVV_S _v_ XSIMD_RVV_TSM _mu), , vec(bvec, vec, T const*))
+            // Masked store: inactive lanes are not written.
+            XSIMD_RVV_OVERLOAD(rvvse_m, (__riscv_vse XSIMD_RVV_S _v_ XSIMD_RVV_TSM _m), , void(bvec, T*, vec))
         }

         template <class A, class T, detail::enable_arithmetic_t<T> = 0>
@@ -423,6 +428,16 @@ namespace xsimd
             return load_aligned<A>(src, convert<T>(), rvv {});
         }

+        // load_masked (runtime mask): native vle*.v vd, (rs1), v0.t with zero-init
+        // passthrough so inactive lanes read as 0, matching xsimd's contract.
+        template <class A, class T, class Mode, detail::enable_arithmetic_t<T> = 0>
+        XSIMD_INLINE batch<T, A> load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, Mode, requires_arch<rvv>) noexcept
+        {
+            using proj_t = project_num_t<T>;
+            const auto zero = detail_rvv::rvvmv_splat(proj_t {});
+            return detail_rvv::rvvle_mu(mask, zero, reinterpret_cast<proj_t const*>(mem));
+        }
+
         // load_complex
         namespace detail_rvv
         {
@@ -500,6 +515,15 @@ namespace xsimd
             store_aligned<A>(dst, src, rvv {});
         }

+        // store_masked (runtime mask): native vse*.v vd, (rs1), v0.t — inactive lanes
+        // are not written (page-fault safe).
+        template <class A, class T, class Mode, detail::enable_arithmetic_t<T> = 0>
+        XSIMD_INLINE void store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, Mode, requires_arch<rvv>) noexcept
+        {
+            using proj_t = project_num_t<T>;
+            detail_rvv::rvvse_m(mask, reinterpret_cast<proj_t*>(mem), src);
+        }
+
         /******************
          * scatter/gather *
          ******************/
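Aside (not part of the diff): a standalone sketch of the mask-undisturbed load that the rvvle_mu overload wraps, written directly against the RVV C intrinsics (v1.0-era `_mu` policy naming; the exact spellings should be verified against your toolchain, and the function name is illustrative).

#include <riscv_vector.h>
#include <cstddef>

// Masked load of up to vl floats: inactive lanes take the zero passthrough
// and perform no memory access, so partial vectors at page edges are safe.
vfloat32m1_t masked_load_f32(float const* p, vbool32_t m, size_t vl)
{
    vfloat32m1_t zero = __riscv_vfmv_v_f_f32m1(0.0f, vl); // passthrough = 0
    return __riscv_vle32_v_f32m1_mu(m, zero, p, vl);      // vle32.v ..., v0.t
}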

include/xsimd/arch/xsimd_sve.hpp

Lines changed: 23 additions & 2 deletions
@@ -101,13 +101,20 @@ namespace xsimd
             return load_aligned<A>(src, convert<T>(), sve {});
         }

-        // load_masked
+        // load_masked (compile-time mask)
         template <class A, class T, bool... Values, class Mode, detail::enable_arithmetic_t<T> = 0>
-        XSIMD_INLINE batch<T, A> load_masked(T const* mem, batch_bool_constant<float, A, Values...>, Mode, requires_arch<sve>) noexcept
+        XSIMD_INLINE batch<T, A> load_masked(T const* mem, batch_bool_constant<T, A, Values...>, convert<T>, Mode, requires_arch<sve>) noexcept
         {
             return svld1(detail_sve::pmask<Values...>(), reinterpret_cast<map_to_sized_type_t<T> const*>(mem));
         }

+        // load_masked (runtime mask)
+        template <class A, class T, class Mode, detail::enable_arithmetic_t<T> = 0>
+        XSIMD_INLINE batch<T, A> load_masked(T const* mem, batch_bool<T, A> mask, convert<T>, Mode, requires_arch<sve>) noexcept
+        {
+            return svld1(mask, reinterpret_cast<project_num_t<T> const*>(mem));
+        }
+
         // load_complex
         template <class A, class T, detail::enable_floating_point_t<T> = 0>
         XSIMD_INLINE batch<std::complex<T>, A> load_complex_aligned(std::complex<T> const* mem, convert<std::complex<T>>, requires_arch<sve>) noexcept
@@ -141,6 +148,20 @@ namespace xsimd
             store_aligned<A>(dst, src, sve {});
         }

+        // store_masked (compile-time mask)
+        template <class A, class T, bool... Values, class Mode, detail::enable_arithmetic_t<T> = 0>
+        XSIMD_INLINE void store_masked(T* mem, batch<T, A> const& src, batch_bool_constant<T, A, Values...>, Mode, requires_arch<sve>) noexcept
+        {
+            svst1(detail_sve::pmask<Values...>(), reinterpret_cast<project_num_t<T>*>(mem), src);
+        }
+
+        // store_masked (runtime mask)
+        template <class A, class T, class Mode, detail::enable_arithmetic_t<T> = 0>
+        XSIMD_INLINE void store_masked(T* mem, batch<T, A> const& src, batch_bool<T, A> mask, Mode, requires_arch<sve>) noexcept
+        {
+            svst1(mask, reinterpret_cast<project_num_t<T>*>(mem), src);
+        }
+
         // store_complex
         template <class A, class T, detail::enable_floating_point_t<T> = 0>
         XSIMD_INLINE void store_complex_aligned(std::complex<T>* dst, batch<std::complex<T>, A> const& src, requires_arch<sve>) noexcept
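Aside (not part of the diff): a hedged usage sketch of the ACLE primitives these overloads sit on, in plain ACLE rather than xsimd. svwhilelt builds exactly the kind of runtime predicate the new batch_bool overloads accept, and svld1/svst1 neither fault on nor write inactive lanes, which is what makes a fully predicated tail loop possible. Function name and loop are illustrative.

#include <arm_sve.h>
#include <cstdint>

// Scale n floats by 2 with a predicated tail: no scalar remainder loop.
void scale_by_two(float* dst, float const* src, int64_t n)
{
    for (int64_t i = 0; i < n; i += svcntw()) // svcntw(): 32-bit lanes per vector
    {
        svbool_t pg = svwhilelt_b32(i, n);        // active lanes: i + lane < n
        svfloat32_t v = svld1(pg, src + i);       // inactive lanes do not fault
        svst1(pg, dst + i, svmul_x(pg, v, 2.0f)); // inactive lanes not written
    }
}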
