Skip to content

Commit fcd2dbb

Browse files
Implementation of xsimd::incr_if and xsimd::decr_if
This provides both the generic fallback and the Intel specialization. Fix #313.
1 parent 7ad7ee3 commit fcd2dbb

File tree

8 files changed

+136
-9
lines changed

8 files changed

+136
-9
lines changed

docs/source/api/batch_manip.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ Conditional expression
2424
.. doxygenfunction:: select(batch_bool_constant<batch<T, A>, Values...> const &cond, batch<T, A> const &true_br, batch<T, A> const &false_br) noexcept
2525
:project: xsimd
2626

27+
28+
In the specific case when one needs to conditionally increment or decrement a
29+
batch based on a mask, :cpp:func:`incr_if` and
30+
:cpp:func:`decr_if` provide specialized versions.

include/xsimd/arch/generic/xsimd_generic_arithmetic.hpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ namespace xsimd
5050
return self - T(1);
5151
}
5252

53+
// decr_if
54+
template <class A, class T, class Mask>
55+
inline batch<T, A> decr_if(batch<T, A> const& self, Mask const& mask, requires_arch<generic>) noexcept
56+
{
57+
return select(mask, decr(self), self);
58+
}
59+
5360
// div
5461
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
5562
inline batch<T, A> div(batch<T, A> const& self, batch<T, A> const& other, requires_arch<generic>) noexcept
@@ -111,19 +118,26 @@ namespace xsimd
111118
return -x * y - z;
112119
}
113120

121+
template <class A, class T>
122+
inline batch<std::complex<T>, A> fnms(batch<std::complex<T>, A> const& x, batch<std::complex<T>, A> const& y, batch<std::complex<T>, A> const& z, requires_arch<generic>) noexcept
123+
{
124+
auto res_r = -fms(x.real(), y.real(), fms(x.imag(), y.imag(), z.real()));
125+
auto res_i = -fma(x.real(), y.imag(), fma(x.imag(), y.real(), z.imag()));
126+
return { res_r, res_i };
127+
}
128+
114129
// incr
115130
template <class A, class T>
116131
inline batch<T, A> incr(batch<T, A> const& self, requires_arch<generic>) noexcept
117132
{
118133
return self + T(1);
119134
}
120135

121-
template <class A, class T>
122-
inline batch<std::complex<T>, A> fnms(batch<std::complex<T>, A> const& x, batch<std::complex<T>, A> const& y, batch<std::complex<T>, A> const& z, requires_arch<generic>) noexcept
136+
// incr_if
137+
template <class A, class T, class Mask>
138+
inline batch<T, A> incr_if(batch<T, A> const& self, Mask const& mask, requires_arch<generic>) noexcept
123139
{
124-
auto res_r = -fms(x.real(), y.real(), fms(x.imag(), y.imag(), z.real()));
125-
auto res_i = -fma(x.real(), y.imag(), fma(x.imag(), y.real(), z.imag()));
126-
return { res_r, res_i };
140+
return select(mask, incr(self), self);
127141
}
128142

129143
// mul

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,13 @@ namespace xsimd
549549
}
550550
}
551551

552+
// decr_if
553+
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
554+
inline batch<T, A> decr_if(batch<T, A> const& self, batch_bool<T, A> const& mask, requires_arch<avx>) noexcept
555+
{
556+
return self + batch<T, A>(mask.data);
557+
}
558+
552559
// div
553560
template <class A>
554561
inline batch<float, A> div(batch<float, A> const& self, batch<float, A> const& other, requires_arch<avx>) noexcept
@@ -749,6 +756,13 @@ namespace xsimd
749756
return _mm256_add_pd(tmp1, tmp2);
750757
}
751758

759+
// incr_if
760+
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
761+
inline batch<T, A> incr_if(batch<T, A> const& self, batch_bool<T, A> const& mask, requires_arch<avx>) noexcept
762+
{
763+
return self - batch<T, A>(mask.data);
764+
}
765+
752766
// insert
753767
template <class A, class T, size_t I, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
754768
inline batch<T, A> insert(batch<T, A> const& self, T val, index<I> pos, requires_arch<avx>) noexcept

include/xsimd/arch/xsimd_scalar.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ namespace xsimd
144144
return x + T(1);
145145
}
146146

147+
template <class T>
148+
inline T incr_if(T const& x, bool mask) noexcept
149+
{
150+
return x + T(mask ? 1 : 0);
151+
}
152+
147153
inline bool all(bool mask)
148154
{
149155
return mask;
@@ -765,6 +771,12 @@ namespace xsimd
765771
return x - T(1);
766772
}
767773

774+
template <class T>
775+
inline T decr_if(T const& x, bool mask) noexcept
776+
{
777+
return x - T(mask ? 1 : 0);
778+
}
779+
768780
#ifdef XSIMD_ENABLE_XTL_COMPLEX
769781
template <class T, bool i3ec>
770782
inline xtl::xcomplex<T, T, i3ec> log2(const xtl::xcomplex<T, T, i3ec>& val) noexcept

include/xsimd/arch/xsimd_sse2.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,13 @@ namespace xsimd
501501
}
502502
}
503503

504+
// decr_if
505+
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
506+
inline batch<T, A> decr_if(batch<T, A> const& self, batch_bool<T, A> const& mask, requires_arch<sse2>) noexcept
507+
{
508+
return self + batch<T, A>(mask.data);
509+
}
510+
504511
// div
505512
template <class A>
506513
inline batch<float, A> div(batch<float, A> const& self, batch<float, A> const& other, requires_arch<sse2>) noexcept
@@ -808,6 +815,13 @@ namespace xsimd
808815
_mm_unpackhi_pd(row[0], row[1]));
809816
}
810817

818+
// incr_if
819+
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
820+
inline batch<T, A> incr_if(batch<T, A> const& self, batch_bool<T, A> const& mask, requires_arch<sse2>) noexcept
821+
{
822+
return self - batch<T, A>(mask.data);
823+
}
824+
811825
// insert
812826
template <class A, class T, size_t I, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
813827
inline batch<T, A> insert(batch<T, A> const& self, T val, index<I> pos, requires_arch<sse2>) noexcept

include/xsimd/types/xsimd_api.hpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,9 @@ namespace xsimd
549549
/**
550550
* @ingroup batch_arithmetic
551551
*
552-
* Subtract 1 to batch \c x
552+
* Subtract 1 from batch \c x.
553553
* @param x batch involved in the decrement.
554-
* @return the substraction of \c x and 1
554+
* @return the subtraction of \c x and 1.
555555
*/
556556
template <class T, class A>
557557
inline batch<T, A> decr(batch<T, A> const& x) noexcept
@@ -560,6 +560,22 @@ namespace xsimd
560560
return kernel::decr<A>(x, A {});
561561
}
562562

563+
/**
564+
* @ingroup batch_arithmetic
565+
*
566+
* Subtract 1 from batch \c x for each element where \c mask is true.
567+
* @param x batch involved in the decrement.
568+
* @param mask whether to perform the decrement or not. Can be a \c
569+
* batch_bool or a \c batch_bool_constant.
570+
* @return \c x - 1 for elements where \c mask is true, \c x unchanged elsewhere.
571+
*/
572+
template <class T, class A, class Mask>
573+
inline batch<T, A> decr_if(batch<T, A> const& x, Mask const& mask) noexcept
574+
{
575+
detail::static_check_supported_config<T, A>();
576+
return kernel::decr_if<A>(x, mask, A {});
577+
}
578+
563579
/**
564580
* @ingroup batch_arithmetic
565581
*
@@ -941,9 +957,9 @@ namespace xsimd
941957
/**
942958
* @ingroup batch_arithmetic
943959
*
944-
* Add 1 to batch \c x
960+
* Add 1 to batch \c x.
945961
* @param x batch involved in the increment.
946-
* @return the sum of \c x and 1
962+
* @return the sum of \c x and 1.
947963
*/
948964
template <class T, class A>
949965
inline batch<T, A> incr(batch<T, A> const& x) noexcept
@@ -952,6 +968,22 @@ namespace xsimd
952968
return kernel::incr<A>(x, A {});
953969
}
954970

971+
/**
972+
* @ingroup batch_arithmetic
973+
*
974+
* Add 1 to batch \c x for each element where \c mask is true.
975+
* @param x batch involved in the increment.
976+
* @param mask whether to perform the increment or not. Can be a \c
977+
* batch_bool or a \c batch_bool_constant.
978+
* @return \c x + 1 for elements where \c mask is true, \c x unchanged elsewhere.
979+
*/
980+
template <class T, class A, class Mask>
981+
inline batch<T, A> incr_if(batch<T, A> const& x, Mask const& mask) noexcept
982+
{
983+
detail::static_check_supported_config<T, A>();
984+
return kernel::incr_if<A>(x, mask, A {});
985+
}
986+
955987
/**
956988
* @ingroup batch_constant
957989
*

test/test_batch.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,18 @@ struct batch_test
258258
INFO("incr(batch)");
259259
CHECK_BATCH_EQ(res, expected);
260260
}
261+
262+
// incr_if
263+
{
264+
array_type expected;
265+
std::transform(lhs.cbegin(), lhs.cend(), expected.begin(),
266+
[](value_type v)
267+
{ return v > 1 ? v + 1 : v; });
268+
batch_type res = xsimd::incr_if(batch_lhs(), batch_lhs() > value_type(1));
269+
INFO("incr_if(batch)");
270+
CHECK_BATCH_EQ(res, expected);
271+
}
272+
261273
// decr
262274
{
263275
array_type expected;
@@ -266,6 +278,17 @@ struct batch_test
266278
INFO("decr(batch)");
267279
CHECK_BATCH_EQ(res, expected);
268280
}
281+
282+
// decr_if
283+
{
284+
array_type expected;
285+
std::transform(lhs.cbegin(), lhs.cend(), expected.begin(),
286+
[](value_type v)
287+
{ return v > 1 ? v - 1 : v; });
288+
batch_type res = xsimd::decr_if(batch_lhs(), batch_lhs() > value_type(1));
289+
INFO("decr_if(batch)");
290+
CHECK_BATCH_EQ(res, expected);
291+
}
269292
}
270293

271294
void test_saturated_arithmetic() const

test/test_xsimd_api.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,12 @@ struct xsimd_api_all_types_functions
10891089
CHECK_EQ(extract(xsimd::decr(T(val0))), val0 - value_type(1));
10901090
}
10911091

1092+
void test_decr_if()
1093+
{
1094+
value_type val0(1);
1095+
CHECK_EQ(extract(xsimd::decr_if(T(val0), T(val0) != T(0))), val0 - value_type(1));
1096+
}
1097+
10921098
void test_div()
10931099
{
10941100
value_type val0(1);
@@ -1133,6 +1139,12 @@ struct xsimd_api_all_types_functions
11331139
CHECK_EQ(extract(xsimd::incr(T(val0))), val0 + value_type(1));
11341140
}
11351141

1142+
void test_incr_if()
1143+
{
1144+
value_type val0(1);
1145+
CHECK_EQ(extract(xsimd::incr_if(T(val0), T(val0) != T(0))), val0 + value_type(1));
1146+
}
1147+
11361148
void test_mul()
11371149
{
11381150
value_type val0(2);
@@ -1176,6 +1188,7 @@ TEST_CASE_TEMPLATE("[xsimd api | all types functions]", B, ALL_TYPES)
11761188
SUBCASE("decr")
11771189
{
11781190
Test.test_decr();
1191+
Test.test_decr_if();
11791192
}
11801193

11811194
SUBCASE("div")
@@ -1206,6 +1219,7 @@ TEST_CASE_TEMPLATE("[xsimd api | all types functions]", B, ALL_TYPES)
12061219
SUBCASE("incr")
12071220
{
12081221
Test.test_incr();
1222+
Test.test_incr_if();
12091223
}
12101224

12111225
SUBCASE("mul")

0 commit comments

Comments
 (0)