Skip to content

Commit 9f63efb

Browse files
Implement reduce_min and reduce_max
Using a generic reducer, aka 'butterfly reduction'. As a side effect, fix a bug in (untested until then) SSSE3 swizzle implementation for int8 and int16. Fix #219 (from 2018 ^^!)
1 parent 9dce801 commit 9f63efb

File tree

12 files changed

+192
-25
lines changed

12 files changed

+192
-25
lines changed

include/xsimd/arch/generic/xsimd_generic_math.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,6 +1976,49 @@ namespace xsimd
19761976
return { reduce_add(self.real()), reduce_add(self.imag()) };
19771977
}
19781978

1979+
namespace detail
1980+
{
1981+
template <class T, T N>
1982+
struct SplitHigh
1983+
{
1984+
static constexpr T get(T i, T)
1985+
{
1986+
return i >= N ? 0 : i + N;
1987+
}
1988+
};
1989+
1990+
template <class Op, class A, class T>
1991+
inline T reduce(Op, batch<T, A> const& self, std::integral_constant<unsigned, 1>) noexcept
1992+
{
1993+
return self.get(0);
1994+
}
1995+
1996+
template <class Op, class A, class T, unsigned Lvl>
1997+
inline T reduce(Op op, batch<T, A> const& self, std::integral_constant<unsigned, Lvl>) noexcept
1998+
{
1999+
using index_type = as_unsigned_integer_t<T>;
2000+
batch<T, A> split = swizzle(self, make_batch_constant<batch<index_type, A>, SplitHigh<index_type, Lvl / 2>>());
2001+
return reduce(op, op(split, self), std::integral_constant<unsigned, Lvl / 2>());
2002+
}
2003+
}
2004+
2005+
// reduce_max
2006+
template <class A, class T>
2007+
inline T reduce_max(batch<T, A> const& self, requires_arch<generic>) noexcept
2008+
{
2009+
return detail::reduce([](batch<T, A> const& x, batch<T, A> const& y)
2010+
{ return max(x, y); },
2011+
self, std::integral_constant<unsigned, batch<T, A>::size>());
2012+
}
2013+
2014+
// reduce_min
2015+
template <class A, class T>
2016+
inline T reduce_min(batch<T, A> const& self, requires_arch<generic>) noexcept
2017+
{
2018+
return detail::reduce([](batch<T, A> const& x, batch<T, A> const& y)
2019+
{ return min(x, y); },
2020+
self, std::integral_constant<unsigned, batch<T, A>::size>());
2021+
}
19792022

19802023
// remainder
19812024
template <class A>

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,28 @@ namespace xsimd
10981098
return reduce_add(blow) + reduce_add(bhigh);
10991099
}
11001100

1101+
// reduce_max
1102+
template <class A, class T, class _ = typename std::enable_if<(sizeof(T) <= 2), void>::type>
1103+
inline T reduce_max(batch<T, A> const& self, requires_arch<avx>) noexcept
1104+
{
1105+
constexpr auto mask = detail::shuffle(1, 0);
1106+
batch<T, A> step = _mm256_permute2f128_si256(self, self, mask);
1107+
batch<T, A> acc = max(self, step);
1108+
__m128i low = _mm256_castsi256_si128(acc);
1109+
return reduce_max(batch<T, sse4_2>(low));
1110+
}
1111+
1112+
// reduce_min
1113+
template <class A, class T, class _ = typename std::enable_if<(sizeof(T) <= 2), void>::type>
1114+
inline T reduce_min(batch<T, A> const& self, requires_arch<avx>) noexcept
1115+
{
1116+
constexpr auto mask = detail::shuffle(1, 0);
1117+
batch<T, A> step = _mm256_permute2f128_si256(self, self, mask);
1118+
batch<T, A> acc = min(self, step);
1119+
__m128i low = _mm256_castsi256_si128(acc);
1120+
return reduce_min(batch<T, sse4_2>(low));
1121+
}
1122+
11011123
// rsqrt
11021124
template <class A>
11031125
inline batch<float, A> rsqrt(batch<float, A> const& val, requires_arch<avx>) noexcept
@@ -1499,12 +1521,13 @@ namespace xsimd
14991521
return bitwise_cast<batch<T, A>>(
15001522
swizzle(bitwise_cast<batch<float, A>>(self), mask));
15011523
}
1524+
15021525
template <class A,
15031526
typename T,
1504-
uint32_t V0,
1505-
uint32_t V1,
1506-
uint32_t V2,
1507-
uint32_t V3,
1527+
uint64_t V0,
1528+
uint64_t V1,
1529+
uint64_t V2,
1530+
uint64_t V3,
15081531
detail::enable_sized_integral_t<T, 8> = 0>
15091532
inline batch<T, A>
15101533
swizzle(batch<T, A> const& self,

include/xsimd/arch/xsimd_avx512bw.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ namespace xsimd
557557
template <class A, uint8_t... Vs>
558558
inline batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch_constant<batch<uint8_t, A>, Vs...> mask, requires_arch<avx512bw>) noexcept
559559
{
560-
return _mm512_permutexvar_epi8((batch<uint8_t, A>)mask, self);
560+
return _mm512_shuffle_epi8(self, (batch<uint8_t, A>)mask);
561561
}
562562

563563
template <class A, uint8_t... Vs>

include/xsimd/arch/xsimd_avx512f.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,27 @@ namespace xsimd
12991299
return reduce_add(blow, avx2 {}) + reduce_add(bhigh, avx2 {});
13001300
}
13011301

1302+
// reduce_max
1303+
template <class A, class T, class _ = typename std::enable_if<(sizeof(T) == 1), void>::type>
1304+
inline T reduce_max(batch<T, A> const& self, requires_arch<avx512f>) noexcept
1305+
{
1306+
constexpr batch_constant<batch<uint64_t, A>, 5, 6, 7, 8, 0, 0, 0, 0> mask;
1307+
batch<T, A> step = _mm512_permutexvar_epi64((batch<uint64_t, A>)mask, self);
1308+
batch<T, A> acc = max(self, step);
1309+
__m256i low = _mm512_castsi512_si256(acc);
1310+
return reduce_max(batch<T, avx2>(low));
1311+
}
1312+
1313+
// reduce_min
1314+
template <class A, class T, class _ = typename std::enable_if<(sizeof(T) == 1), void>::type>
1315+
inline T reduce_min(batch<T, A> const& self, requires_arch<avx512f>) noexcept
1316+
{
1317+
constexpr batch_constant<batch<uint64_t, A>, 5, 6, 7, 8, 0, 0, 0, 0> mask;
1318+
batch<T, A> step = _mm512_permutexvar_epi64((batch<uint64_t, A>)mask, self);
1319+
batch<T, A> acc = min(self, step);
1320+
__m256i low = _mm512_castsi512_si256(acc);
1321+
return reduce_min(batch<T, avx2>(low));
1322+
}
13021323

13031324
// rsqrt
13041325
template <class A>

include/xsimd/arch/xsimd_generic_fwd.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
#ifndef XSIMD_GENERIC_FWD_HPP
1313
#define XSIMD_GENERIC_FWD_HPP
1414

15+
#include "../types/xsimd_batch_constant.hpp"
16+
1517
#include <type_traits>
1618

1719
namespace xsimd
1820
{
19-
2021
namespace kernel
2122
{
2223
// forward declaration

include/xsimd/arch/xsimd_neon.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1577,7 +1577,6 @@ namespace xsimd
15771577
return vget_lane_f32(tmp, 0);
15781578
}
15791579

1580-
15811580
/**********
15821581
* select *
15831582
**********/

include/xsimd/arch/xsimd_neon64.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,6 @@ namespace xsimd
753753
return vaddvq_f64(arg);
754754
}
755755

756-
757756
/**********
758757
* select *
759758
**********/

include/xsimd/arch/xsimd_sse2.hpp

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ namespace xsimd
3333
{
3434
using namespace types;
3535

36+
namespace detail
37+
{
38+
constexpr uint32_t shuffle(uint32_t w, uint32_t x, uint32_t y, uint32_t z)
39+
{
40+
return (z << 6) | (y << 4) | (x << 2) | w;
41+
}
42+
constexpr uint32_t shuffle(uint32_t x, uint32_t y)
43+
{
44+
return (y << 1) | x;
45+
}
46+
}
47+
3648
// fwd
3749
template <class A, class T, size_t I>
3850
inline batch<T, A> insert(batch<T, A> const& self, T val, index<I>, requires_arch<generic>) noexcept;
@@ -1155,6 +1167,50 @@ namespace xsimd
11551167
__m128 tmp1 = _mm_add_ss(tmp0, _mm_shuffle_ps(tmp0, tmp0, 1));
11561168
return _mm_cvtss_f32(tmp1);
11571169
}
1170+
1171+
// reduce_max
1172+
template <class A, class T, class _ = typename std::enable_if<(sizeof(T) <= 2), void>::type>
1173+
inline T reduce_max(batch<T, A> const& self, requires_arch<sse2>) noexcept
1174+
{
1175+
constexpr auto mask0 = detail::shuffle(2, 3, 0, 0);
1176+
batch<T, A> step0 = _mm_shuffle_epi32(self, mask0);
1177+
batch<T, A> acc0 = max(self, step0);
1178+
1179+
constexpr auto mask1 = detail::shuffle(1, 0, 0, 0);
1180+
batch<T, A> step1 = _mm_shuffle_epi32(acc0, mask1);
1181+
batch<T, A> acc1 = max(acc0, step1);
1182+
1183+
constexpr auto mask2 = detail::shuffle(1, 0, 0, 0);
1184+
batch<T, A> step2 = _mm_shufflelo_epi16(acc1, mask2);
1185+
batch<T, A> acc2 = max(acc1, step2);
1186+
if (sizeof(T) == 2)
1187+
return acc2.get(0);
1188+
batch<T, A> step3 = bitwise_cast<batch<T, A>>(bitwise_cast<batch<uint16_t, A>>(acc2) >> 8);
1189+
batch<T, A> acc3 = max(acc2, step3);
1190+
return acc3.get(0);
1191+
}
1192+
1193+
// reduce_min
1194+
template <class A, class T, class _ = typename std::enable_if<(sizeof(T) <= 2), void>::type>
1195+
inline T reduce_min(batch<T, A> const& self, requires_arch<sse2>) noexcept
1196+
{
1197+
constexpr auto mask0 = detail::shuffle(2, 3, 0, 0);
1198+
batch<T, A> step0 = _mm_shuffle_epi32(self, mask0);
1199+
batch<T, A> acc0 = min(self, step0);
1200+
1201+
constexpr auto mask1 = detail::shuffle(1, 0, 0, 0);
1202+
batch<T, A> step1 = _mm_shuffle_epi32(acc0, mask1);
1203+
batch<T, A> acc1 = min(acc0, step1);
1204+
1205+
constexpr auto mask2 = detail::shuffle(1, 0, 0, 0);
1206+
batch<T, A> step2 = _mm_shufflelo_epi16(acc1, mask2);
1207+
batch<T, A> acc2 = min(acc1, step2);
1208+
if (sizeof(T) == 2)
1209+
return acc2.get(0);
1210+
batch<T, A> step3 = bitwise_cast<batch<T, A>>(bitwise_cast<batch<uint16_t, A>>(acc2) >> 8);
1211+
batch<T, A> acc3 = min(acc2, step3);
1212+
return acc3.get(0);
1213+
}
11581214
// TODO: move this in xsimd_generic
11591215
namespace detail
11601216
{
@@ -1207,7 +1263,6 @@ namespace xsimd
12071263
return _mm_cvtsd_f64(_mm_add_sd(self, _mm_unpackhi_pd(self, self)));
12081264
}
12091265

1210-
12111266
// rsqrt
12121267
template <class A>
12131268
inline batch<float, A> rsqrt(batch<float, A> const& val, requires_arch<sse2>) noexcept
@@ -1541,18 +1596,6 @@ namespace xsimd
15411596

15421597
// swizzle
15431598

1544-
namespace detail
1545-
{
1546-
constexpr uint32_t shuffle(uint32_t w, uint32_t x, uint32_t y, uint32_t z)
1547-
{
1548-
return (z << 6) | (y << 4) | (x << 2) | w;
1549-
}
1550-
constexpr uint32_t shuffle(uint32_t x, uint32_t y)
1551-
{
1552-
return (y << 1) | x;
1553-
}
1554-
}
1555-
15561599
template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3>
15571600
inline batch<float, A> swizzle(batch<float, A> const& self, batch_constant<batch<uint32_t, A>, V0, V1, V2, V3>, requires_arch<sse2>) noexcept
15581601
{

include/xsimd/arch/xsimd_sse3.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ namespace xsimd
4242
return _mm_lddqu_si128((__m128i const*)mem);
4343
}
4444

45-
4645
// reduce_add
4746
template <class A>
4847
inline float reduce_add(batch<float, A> const& self, requires_arch<sse3>) noexcept
@@ -58,7 +57,6 @@ namespace xsimd
5857
return _mm_cvtsd_f64(tmp0);
5958
}
6059

61-
6260
}
6361

6462
}

include/xsimd/arch/xsimd_ssse3.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ namespace xsimd
118118
template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>
119119
inline batch<int16_t, A> swizzle(batch<int16_t, A> const& self, batch_constant<batch<uint16_t, A>, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<ssse3>) noexcept
120120
{
121-
return bitwise_cast<int16_t>(swizzle(bitwise_cast<uint16_t>(self), mask, ssse3 {}));
121+
return bitwise_cast<batch<int16_t, A>>(swizzle(bitwise_cast<batch<uint16_t, A>>(self), mask, ssse3 {}));
122122
}
123123

124124
template <class A, uint8_t V0, uint8_t V1, uint8_t V2, uint8_t V3, uint8_t V4, uint8_t V5, uint8_t V6, uint8_t V7,
@@ -132,7 +132,7 @@ namespace xsimd
132132
uint8_t V8, uint8_t V9, uint8_t V10, uint8_t V11, uint8_t V12, uint8_t V13, uint8_t V14, uint8_t V15>
133133
inline batch<int8_t, A> swizzle(batch<int8_t, A> const& self, batch_constant<batch<uint8_t, A>, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15> mask, requires_arch<ssse3>) noexcept
134134
{
135-
return bitwise_cast<int8_t>(swizzle(bitwise_cast<uint8_t>(self), mask, ssse3 {}));
135+
return bitwise_cast<batch<int8_t, A>>(swizzle(bitwise_cast<batch<uint8_t, A>>(self), mask, ssse3 {}));
136136
}
137137

138138
}

0 commit comments

Comments
 (0)