Skip to content

Commit b97d464

Browse files
committed
feat: add mulhi, mullo, and mulhilo for integer batches
Adds three integer-multiplication primitives exposed via the public API: - mullo(x, y): low half of the lane-wise product (equivalent to x * y) - mulhi(x, y): high half of the lane-wise product - mulhilo(x, y): returns {mulhi, mullo} as a pair Native kernels are provided for: - NEON (vmull_* + vshrn for 8/16/32-bit; software path for 64-bit) - SVE (svmulh_x) - RVV (rvvmulh) - SSE2 (mulhi_epi16 / mulhi_epu16) - SSE4.1 (mul_epi32/mul_epu32 + blend for 32-bit; shared 64-bit core) - AVX2 (mulhi_epi16/epu16, mul_epi32/mul_epu32 + blend; shared 64-bit core) - AVX-512F (shared 64-bit core) - AVX-512BW (mulhi_epi16/epu16) The 64-bit x86 cores share a single implementation in common/xsimd_common_arithmetic.hpp: mulhi_u64_core and mulhi_i64_core express the ll/lh/hl/hh decomposition with xsimd batch operators (&, >>, +, -, bitwise_cast) plus an arch-specific widening-mul functor (_mm*_mul_epu32). This eliminates three copies of the same 64x64 -> hi software path and unifies the signed-fixup to a single arithmetic-shift-by-63 pattern (maps to vpsraq on AVX-512, emulated on SSE4.1/AVX2 via bitwise_rshift). The generic fallback in common dispatches per-type through mulhi_helper, using a wider native integer for <=32-bit types and software split-and-multiply (or __int128 when available) for 64-bit.
1 parent eb1c1ba commit b97d464

13 files changed

Lines changed: 567 additions & 0 deletions

docs/source/api/arithmetic_index.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ Binary operations:
4040
+---------------------------------------+----------------------------------------------------+
4141
| :cpp:func:`mul` | per slot multiply |
4242
+---------------------------------------+----------------------------------------------------+
43+
| :cpp:func:`mullo` | low N bits of the 2N-bit integer product |
44+
+---------------------------------------+----------------------------------------------------+
45+
| :cpp:func:`mulhi` | high N bits of the 2N-bit integer product |
46+
+---------------------------------------+----------------------------------------------------+
47+
| :cpp:func:`mulhilo` | pair {hi, lo} of the 2N-bit integer product |
48+
+---------------------------------------+----------------------------------------------------+
4349
| :cpp:func:`div` | per slot division |
4450
+---------------------------------------+----------------------------------------------------+
4551
| :cpp:func:`mod` | per slot modulo |

include/xsimd/arch/common/xsimd_common_arithmetic.hpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,118 @@ namespace xsimd
177177
self, other);
178178
}
179179

180+
// mulhi
181+
namespace detail
182+
{
183+
template <class T>
184+
struct mulhi_helper
185+
{
186+
// default: use a wider native integer type
187+
using wider = typename std::conditional<
188+
std::is_signed<T>::value,
189+
typename std::conditional<sizeof(T) == 1, int16_t,
190+
typename std::conditional<sizeof(T) == 2, int32_t, int64_t>::type>::type,
191+
typename std::conditional<sizeof(T) == 1, uint16_t,
192+
typename std::conditional<sizeof(T) == 2, uint32_t, uint64_t>::type>::type>::type;
193+
194+
static XSIMD_INLINE T compute(T x, T y) noexcept
195+
{
196+
constexpr int shift = 8 * sizeof(T);
197+
return static_cast<T>((static_cast<wider>(x) * static_cast<wider>(y)) >> shift);
198+
}
199+
};
200+
201+
// 64-bit unsigned software mulhi via 32-bit splits
202+
XSIMD_INLINE uint64_t mulhi_u64(uint64_t x, uint64_t y) noexcept
203+
{
204+
#if defined(__SIZEOF_INT128__)
205+
return static_cast<uint64_t>((static_cast<unsigned __int128>(x) * static_cast<unsigned __int128>(y)) >> 64);
206+
#else
207+
uint64_t xl = x & 0xffffffffULL;
208+
uint64_t xh = x >> 32;
209+
uint64_t yl = y & 0xffffffffULL;
210+
uint64_t yh = y >> 32;
211+
uint64_t ll = xl * yl;
212+
uint64_t lh = xl * yh;
213+
uint64_t hl = xh * yl;
214+
uint64_t hh = xh * yh;
215+
uint64_t mid = (ll >> 32) + (lh & 0xffffffffULL) + (hl & 0xffffffffULL);
216+
return hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
217+
#endif
218+
}
219+
220+
XSIMD_INLINE int64_t mulhi_i64(int64_t x, int64_t y) noexcept
{
#if defined(__SIZEOF_INT128__)
    return static_cast<int64_t>((static_cast<__int128>(x) * static_cast<__int128>(y)) >> 64);
#else
    // Two's-complement identity: signed-hi = unsigned-hi, minus the
    // other operand for each negative input.
    uint64_t hi = mulhi_u64(static_cast<uint64_t>(x), static_cast<uint64_t>(y));
    if (x < 0)
        hi -= static_cast<uint64_t>(y);
    if (y < 0)
        hi -= static_cast<uint64_t>(x);
    return static_cast<int64_t>(hi);
#endif
}
233+
234+
template <>
235+
struct mulhi_helper<uint64_t>
236+
{
237+
static XSIMD_INLINE uint64_t compute(uint64_t x, uint64_t y) noexcept { return mulhi_u64(x, y); }
238+
};
239+
240+
template <>
241+
struct mulhi_helper<int64_t>
242+
{
243+
static XSIMD_INLINE int64_t compute(int64_t x, int64_t y) noexcept { return mulhi_i64(x, y); }
244+
};
245+
246+
// Compute the high 64 bits of each lane-wise 64x64 unsigned product,
247+
// given a "widening mul" functor WMul that takes two batch<uint64_t,A>
248+
// and returns batch<uint64_t,A> containing the 64-bit product of the
249+
// low 32 bits of each 64-bit lane (i.e. _mm*_mul_epu32 wrapped).
250+
template <class A, class WMul>
251+
XSIMD_INLINE batch<uint64_t, A> mulhi_u64_core(batch<uint64_t, A> const& x,
252+
batch<uint64_t, A> const& y,
253+
WMul mul_epu32) noexcept
254+
{
255+
using B = batch<uint64_t, A>;
256+
const B mask(uint64_t(0xffffffffULL));
257+
B xl = x & mask;
258+
B xh = x >> 32;
259+
B yl = y & mask;
260+
B yh = y >> 32;
261+
B ll = mul_epu32(xl, yl);
262+
B lh = mul_epu32(xl, yh);
263+
B hl = mul_epu32(xh, yl);
264+
B hh = mul_epu32(xh, yh);
265+
B mid = (ll >> 32) + (lh & mask) + (hl & mask);
266+
return hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
267+
}
268+
269+
// Signed variant: unsigned core + sign fixup via arithmetic shift-by-63.
270+
template <class A, class WMul>
271+
XSIMD_INLINE batch<int64_t, A> mulhi_i64_core(batch<int64_t, A> const& x,
272+
batch<int64_t, A> const& y,
273+
WMul mul_epu32) noexcept
274+
{
275+
auto ux = ::xsimd::bitwise_cast<uint64_t>(x);
276+
auto uy = ::xsimd::bitwise_cast<uint64_t>(y);
277+
auto uhi = mulhi_u64_core<A>(ux, uy, mul_epu32);
278+
auto sa = ::xsimd::bitwise_cast<uint64_t>(x >> 63);
279+
auto sb = ::xsimd::bitwise_cast<uint64_t>(y >> 63);
280+
return ::xsimd::bitwise_cast<int64_t>(uhi - (uy & sa) - (ux & sb));
281+
}
282+
}
283+
284+
template <class A, class T, class /*=std::enable_if_t<std::is_integral<T>::value>*/>
285+
XSIMD_INLINE batch<T, A> mulhi(batch<T, A> const& self, batch<T, A> const& other, requires_arch<common>) noexcept
286+
{
287+
return detail::apply([](T x, T y) noexcept -> T
288+
{ return detail::mulhi_helper<T>::compute(x, y); },
289+
self, other);
290+
}
291+
180292
// rotl
181293
template <class A, class T, class STy>
182294
XSIMD_INLINE batch<T, A> rotl(batch<T, A> const& self, STy other, requires_arch<common>) noexcept

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,50 @@ namespace xsimd
928928
}
929929
}
930930

931+
// mulhi
// 16-bit lanes: AVX2 has native high-half multiply instructions.
template <class A>
XSIMD_INLINE batch<int16_t, A> mulhi(batch<int16_t, A> const& self, batch<int16_t, A> const& other, requires_arch<avx2>) noexcept
{
    return _mm256_mulhi_epi16(self, other);
}
template <class A>
XSIMD_INLINE batch<uint16_t, A> mulhi(batch<uint16_t, A> const& self, batch<uint16_t, A> const& other, requires_arch<avx2>) noexcept
{
    return _mm256_mulhi_epu16(self, other);
}
942+
template <class A>
943+
XSIMD_INLINE batch<int32_t, A> mulhi(batch<int32_t, A> const& self, batch<int32_t, A> const& other, requires_arch<avx2>) noexcept
944+
{
945+
__m256i even = _mm256_mul_epi32(self, other);
946+
__m256i odd = _mm256_mul_epi32(_mm256_shuffle_epi32(self, _MM_SHUFFLE(3, 3, 1, 1)),
947+
_mm256_shuffle_epi32(other, _MM_SHUFFLE(3, 3, 1, 1)));
948+
__m256i even_hi = _mm256_srli_epi64(even, 32);
949+
return _mm256_blend_epi16(even_hi, odd, 0xCC);
950+
}
951+
template <class A>
952+
XSIMD_INLINE batch<uint32_t, A> mulhi(batch<uint32_t, A> const& self, batch<uint32_t, A> const& other, requires_arch<avx2>) noexcept
953+
{
954+
__m256i even = _mm256_mul_epu32(self, other);
955+
__m256i odd = _mm256_mul_epu32(_mm256_srli_epi64(self, 32), _mm256_srli_epi64(other, 32));
956+
__m256i even_hi = _mm256_srli_epi64(even, 32);
957+
return _mm256_blend_epi16(even_hi, odd, 0xCC);
958+
}
959+
960+
template <class A>
961+
XSIMD_INLINE batch<uint64_t, A> mulhi(batch<uint64_t, A> const& self, batch<uint64_t, A> const& other, requires_arch<avx2>) noexcept
962+
{
963+
return detail::mulhi_u64_core<A>(self, other,
964+
[](batch<uint64_t, A> a, batch<uint64_t, A> b)
965+
{ return batch<uint64_t, A>(_mm256_mul_epu32(a, b)); });
966+
}
967+
template <class A>
968+
XSIMD_INLINE batch<int64_t, A> mulhi(batch<int64_t, A> const& self, batch<int64_t, A> const& other, requires_arch<avx2>) noexcept
969+
{
970+
return detail::mulhi_i64_core<A>(self, other,
971+
[](batch<uint64_t, A> a, batch<uint64_t, A> b)
972+
{ return batch<uint64_t, A>(_mm256_mul_epu32(a, b)); });
973+
}
974+
931975
// reduce_add
932976
template <class A, class T, class = std::enable_if_t<std::is_integral<T>::value>>
933977
XSIMD_INLINE T reduce_add(batch<T, A> const& self, requires_arch<avx2>) noexcept

include/xsimd/arch/xsimd_avx512bw.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,18 @@ namespace xsimd
470470
}
471471
}
472472

473+
// mulhi
// 16-bit lanes: AVX-512BW has native high-half multiply instructions.
template <class A>
XSIMD_INLINE batch<int16_t, A> mulhi(batch<int16_t, A> const& self, batch<int16_t, A> const& other, requires_arch<avx512bw>) noexcept
{
    return _mm512_mulhi_epi16(self, other);
}
template <class A>
XSIMD_INLINE batch<uint16_t, A> mulhi(batch<uint16_t, A> const& self, batch<uint16_t, A> const& other, requires_arch<avx512bw>) noexcept
{
    return _mm512_mulhi_epu16(self, other);
}
484+
473485
// neq
474486
template <class A, class T, class = std::enable_if_t<std::is_integral<T>::value>>
475487
XSIMD_INLINE batch_bool<T, A> neq(batch<T, A> const& self, batch<T, A> const& other, requires_arch<avx512bw>) noexcept

include/xsimd/arch/xsimd_avx512f.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,6 +1772,41 @@ namespace xsimd
17721772
}
17731773
}
17741774

1775+
// mulhi
1776+
template <class A>
1777+
XSIMD_INLINE batch<int32_t, A> mulhi(batch<int32_t, A> const& self, batch<int32_t, A> const& other, requires_arch<avx512f>) noexcept
1778+
{
1779+
__m512i even = _mm512_mul_epi32(self, other);
1780+
__m512i odd = _mm512_mul_epi32(_mm512_shuffle_epi32(self, _MM_PERM_ENUM(_MM_SHUFFLE(3, 3, 1, 1))),
1781+
_mm512_shuffle_epi32(other, _MM_PERM_ENUM(_MM_SHUFFLE(3, 3, 1, 1))));
1782+
__m512i even_hi = _mm512_srli_epi64(even, 32);
1783+
// merge: even_hi has hi in low-32 of each 64, odd has hi in high-32 of each 64
1784+
return _mm512_mask_blend_epi32(static_cast<__mmask16>(0xAAAA), even_hi, odd);
1785+
}
1786+
template <class A>
1787+
XSIMD_INLINE batch<uint32_t, A> mulhi(batch<uint32_t, A> const& self, batch<uint32_t, A> const& other, requires_arch<avx512f>) noexcept
1788+
{
1789+
__m512i even = _mm512_mul_epu32(self, other);
1790+
__m512i odd = _mm512_mul_epu32(_mm512_srli_epi64(self, 32), _mm512_srli_epi64(other, 32));
1791+
__m512i even_hi = _mm512_srli_epi64(even, 32);
1792+
return _mm512_mask_blend_epi32(static_cast<__mmask16>(0xAAAA), even_hi, odd);
1793+
}
1794+
1795+
template <class A>
1796+
XSIMD_INLINE batch<uint64_t, A> mulhi(batch<uint64_t, A> const& self, batch<uint64_t, A> const& other, requires_arch<avx512f>) noexcept
1797+
{
1798+
return detail::mulhi_u64_core<A>(self, other,
1799+
[](batch<uint64_t, A> a, batch<uint64_t, A> b)
1800+
{ return batch<uint64_t, A>(_mm512_mul_epu32(a, b)); });
1801+
}
1802+
template <class A>
1803+
XSIMD_INLINE batch<int64_t, A> mulhi(batch<int64_t, A> const& self, batch<int64_t, A> const& other, requires_arch<avx512f>) noexcept
1804+
{
1805+
return detail::mulhi_i64_core<A>(self, other,
1806+
[](batch<uint64_t, A> a, batch<uint64_t, A> b)
1807+
{ return batch<uint64_t, A>(_mm512_mul_epu32(a, b)); });
1808+
}
1809+
17751810
// nearbyint
17761811
template <class A>
17771812
XSIMD_INLINE batch<float, A> nearbyint(batch<float, A> const& self, requires_arch<avx512f>) noexcept

include/xsimd/arch/xsimd_common_fwd.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ namespace xsimd
5858
template <class A, class T, class = std::enable_if_t<std::is_integral<T>::value>>
XSIMD_INLINE batch<T, A> mul(batch<T, A> const& self, batch<T, A> const& other, requires_arch<common>) noexcept;
// High half of the lane-wise 2N-bit integer product (common fallback).
template <class A, class T, class = std::enable_if_t<std::is_integral<T>::value>>
XSIMD_INLINE batch<T, A> mulhi(batch<T, A> const& self, batch<T, A> const& other, requires_arch<common>) noexcept;
template <class A, class T, class = std::enable_if_t<std::is_integral<T>::value>>
XSIMD_INLINE batch<T, A> sadd(batch<T, A> const& self, batch<T, A> const& other, requires_arch<common>) noexcept;
template <class A, class T, class = std::enable_if_t<std::is_integral<T>::value>>
XSIMD_INLINE batch<T, A> ssub(batch<T, A> const& self, batch<T, A> const& other, requires_arch<common>) noexcept;
@@ -120,6 +122,18 @@ namespace xsimd
120122
XSIMD_INLINE constexpr bool is_only_from_lo(batch_constant<T, A, Vs...>) noexcept;
template <typename T, class A, T... Vs>
XSIMD_INLINE constexpr bool is_only_from_hi(batch_constant<T, A, Vs...>) noexcept;

// Shared 64-bit mulhi cores, defined in xsimd_common_arithmetic.hpp.
// Forward-declared here so arch-specific kernels (SSE4.1, AVX2,
// AVX-512) can name them with an explicit template argument.
// WMul is an arch-specific widening-multiply functor returning the
// 64-bit product of the low 32 bits of each 64-bit lane.
template <class A, class WMul>
XSIMD_INLINE batch<uint64_t, A> mulhi_u64_core(batch<uint64_t, A> const& x,
                                               batch<uint64_t, A> const& y,
                                               WMul mul_epu32) noexcept;
template <class A, class WMul>
XSIMD_INLINE batch<int64_t, A> mulhi_i64_core(batch<int64_t, A> const& x,
                                              batch<int64_t, A> const& y,
                                              WMul mul_epu32) noexcept;
123137
}
124138
}
125139
}

include/xsimd/arch/xsimd_neon.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,54 @@ namespace xsimd
10371037
return wrap::x_vmulq<map_to_sized_type_t<T>>(register_type(lhs), register_type(rhs));
10381038
}
10391039

1040+
/*********
 * mulhi *
 *********/

template <class A>
XSIMD_INLINE batch<int8_t, A> mulhi(batch<int8_t, A> const& lhs, batch<int8_t, A> const& rhs, requires_arch<neon>) noexcept
{
    // Widen each half to 16 bits, then shift-narrow keeping the top byte.
    int16x8_t wide_lo = vmull_s8(vget_low_s8(lhs), vget_low_s8(rhs));
    int16x8_t wide_hi = vmull_s8(vget_high_s8(lhs), vget_high_s8(rhs));
    return vcombine_s8(vshrn_n_s16(wide_lo, 8), vshrn_n_s16(wide_hi, 8));
}
template <class A>
XSIMD_INLINE batch<uint8_t, A> mulhi(batch<uint8_t, A> const& lhs, batch<uint8_t, A> const& rhs, requires_arch<neon>) noexcept
{
    uint16x8_t wide_lo = vmull_u8(vget_low_u8(lhs), vget_low_u8(rhs));
    uint16x8_t wide_hi = vmull_u8(vget_high_u8(lhs), vget_high_u8(rhs));
    return vcombine_u8(vshrn_n_u16(wide_lo, 8), vshrn_n_u16(wide_hi, 8));
}
1058+
template <class A>
XSIMD_INLINE batch<int16_t, A> mulhi(batch<int16_t, A> const& lhs, batch<int16_t, A> const& rhs, requires_arch<neon>) noexcept
{
    // Widen each half to 32 bits, then shift-narrow keeping the top 16.
    int32x4_t wide_lo = vmull_s16(vget_low_s16(lhs), vget_low_s16(rhs));
    int32x4_t wide_hi = vmull_s16(vget_high_s16(lhs), vget_high_s16(rhs));
    return vcombine_s16(vshrn_n_s32(wide_lo, 16), vshrn_n_s32(wide_hi, 16));
}
template <class A>
XSIMD_INLINE batch<uint16_t, A> mulhi(batch<uint16_t, A> const& lhs, batch<uint16_t, A> const& rhs, requires_arch<neon>) noexcept
{
    uint32x4_t wide_lo = vmull_u16(vget_low_u16(lhs), vget_low_u16(rhs));
    uint32x4_t wide_hi = vmull_u16(vget_high_u16(lhs), vget_high_u16(rhs));
    return vcombine_u16(vshrn_n_u32(wide_lo, 16), vshrn_n_u32(wide_hi, 16));
}
1072+
template <class A>
XSIMD_INLINE batch<int32_t, A> mulhi(batch<int32_t, A> const& lhs, batch<int32_t, A> const& rhs, requires_arch<neon>) noexcept
{
    // Widen each half to 64 bits, then shift-narrow keeping the top 32.
    int64x2_t wide_lo = vmull_s32(vget_low_s32(lhs), vget_low_s32(rhs));
    int64x2_t wide_hi = vmull_s32(vget_high_s32(lhs), vget_high_s32(rhs));
    return vcombine_s32(vshrn_n_s64(wide_lo, 32), vshrn_n_s64(wide_hi, 32));
}
template <class A>
XSIMD_INLINE batch<uint32_t, A> mulhi(batch<uint32_t, A> const& lhs, batch<uint32_t, A> const& rhs, requires_arch<neon>) noexcept
{
    uint64x2_t wide_lo = vmull_u32(vget_low_u32(lhs), vget_low_u32(rhs));
    uint64x2_t wide_hi = vmull_u32(vget_high_u32(lhs), vget_high_u32(rhs));
    return vcombine_u32(vshrn_n_u64(wide_lo, 32), vshrn_n_u64(wide_hi, 32));
}
// 64-bit intentionally falls through to the common scalar fallback
1087+
10401088
/*******
10411089
* div *
10421090
*******/

include/xsimd/arch/xsimd_rvv.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,9 @@ namespace xsimd
567567
(__riscv_vmul),
568568
(__riscv_vmul),
569569
(__riscv_vfmul), , vec(vec, vec))
570+
// rvvmulh: lane-wise high-half multiply (signed: vmulh, unsigned: vmulhu)
XSIMD_RVV_OVERLOAD2(rvvmulh,
                    (__riscv_vmulh),
                    (__riscv_vmulhu), , vec(vec, vec))
570573
XSIMD_RVV_OVERLOAD3(rvvdiv,
571574
(__riscv_vdiv),
572575
(__riscv_vdivu),
@@ -659,6 +662,13 @@ namespace xsimd
659662
return detail_rvv::rvvmul(lhs, rhs);
660663
}
661664

665+
// mulhi
666+
template <class A, class T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
667+
XSIMD_INLINE batch<T, A> mulhi(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<rvv>) noexcept
668+
{
669+
return detail::rvvmulh(lhs, rhs);
670+
}
671+
662672
// div
663673
template <class A, class T, typename detail::enable_arithmetic_t<T> = 0>
664674
XSIMD_INLINE batch<T, A> div(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<rvv>) noexcept

include/xsimd/arch/xsimd_sse2.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,18 @@ namespace xsimd
14381438
return _mm_mullo_epi16(self, other);
14391439
}
14401440

1441+
// mulhi
// 16-bit lanes: SSE2 has native high-half multiply instructions.
template <class A>
XSIMD_INLINE batch<int16_t, A> mulhi(batch<int16_t, A> const& self, batch<int16_t, A> const& other, requires_arch<sse2>) noexcept
{
    return _mm_mulhi_epi16(self, other);
}
template <class A>
XSIMD_INLINE batch<uint16_t, A> mulhi(batch<uint16_t, A> const& self, batch<uint16_t, A> const& other, requires_arch<sse2>) noexcept
{
    return _mm_mulhi_epu16(self, other);
}
1452+
14411453
// nearbyint_as_int
14421454
template <class A>
14431455
XSIMD_INLINE batch<int32_t, A> nearbyint_as_int(batch<float, A> const& self,

0 commit comments

Comments
 (0)