Skip to content

Commit 29a0bd1

Browse files
Provide and test comparison operators overloads that accept scalars
Instead of writing the actual scalar overloads, rely on the nice property of friend operators defined inline that allow for implicit conversion. Fix #229
1 parent 73c90a2 commit 29a0bd1

File tree

2 files changed

+123
-60
lines changed

2 files changed

+123
-60
lines changed

include/xsimd/types/xsimd_batch.hpp

Lines changed: 107 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace xsimd
2323
{
2424
template <class T, class A = default_arch>
2525
class batch;
26+
2627
namespace types
2728
{
2829
template <class T, class A>
@@ -75,6 +76,30 @@ namespace xsimd
7576

7677
}
7778

79+
namespace details
80+
{
81+
// These functions are forwarded declared here so that they can be used by friend functions
82+
// with batch<T, A>. Their implementation must appear only once the
83+
// kernel implementations have been included.
84+
template <class T, class A>
85+
inline batch_bool<T, A> eq(batch<T, A> const& self, batch<T, A> const& other) noexcept;
86+
87+
template <class T, class A>
88+
inline batch_bool<T, A> neq(batch<T, A> const& self, batch<T, A> const& other) noexcept;
89+
90+
template <class T, class A>
91+
inline batch_bool<T, A> ge(batch<T, A> const& self, batch<T, A> const& other) noexcept;
92+
93+
template <class T, class A>
94+
inline batch_bool<T, A> le(batch<T, A> const& self, batch<T, A> const& other) noexcept;
95+
96+
template <class T, class A>
97+
inline batch_bool<T, A> gt(batch<T, A> const& self, batch<T, A> const& other) noexcept;
98+
99+
template <class T, class A>
100+
inline batch_bool<T, A> lt(batch<T, A> const& self, batch<T, A> const& other) noexcept;
101+
}
102+
78103
/**
79104
* @brief batch of integer or floating point values.
80105
*
@@ -133,13 +158,33 @@ namespace xsimd
133158

134159
T get(std::size_t i) const noexcept;
135160

136-
// comparison operators
137-
inline batch_bool_type operator==(batch const& other) const noexcept;
138-
inline batch_bool_type operator!=(batch const& other) const noexcept;
139-
inline batch_bool_type operator>=(batch const& other) const noexcept;
140-
inline batch_bool_type operator<=(batch const& other) const noexcept;
141-
inline batch_bool_type operator>(batch const& other) const noexcept;
142-
inline batch_bool_type operator<(batch const& other) const noexcept;
161+
// comparison operators. Defined as friend to enable automatic
162+
// conversion of parameters from scalar to batch, at the cost of using a
163+
// proxy implementation from details::.
164+
friend batch_bool<T, A> operator==(batch const& self, batch const& other) noexcept
165+
{
166+
return details::eq<T, A>(self, other);
167+
}
168+
friend batch_bool<T, A> operator!=(batch const& self, batch const& other) noexcept
169+
{
170+
return details::neq<T, A>(self, other);
171+
}
172+
friend batch_bool<T, A> operator>=(batch const& self, batch const& other) noexcept
173+
{
174+
return details::ge<T, A>(self, other);
175+
}
176+
friend batch_bool<T, A> operator<=(batch const& self, batch const& other) noexcept
177+
{
178+
return details::le<T, A>(self, other);
179+
}
180+
friend batch_bool<T, A> operator>(batch const& self, batch const& other) noexcept
181+
{
182+
return details::gt<T, A>(self, other);
183+
}
184+
friend batch_bool<T, A> operator<(batch const& self, batch const& other) noexcept
185+
{
186+
return details::lt<T, A>(self, other);
187+
}
143188

144189
// Update operators
145190
inline batch& operator+=(batch const& other) noexcept;
@@ -650,65 +695,67 @@ namespace xsimd
650695
/******************************
651696
* batch comparison operators *
652697
******************************/
653-
654-
/**
655-
* Shorthand for xsimd::eq()
656-
*/
657-
template <class T, class A>
658-
inline batch_bool<T, A> batch<T, A>::operator==(batch<T, A> const& other) const noexcept
698+
namespace details
659699
{
660-
detail::static_check_supported_config<T, A>();
661-
return kernel::eq<A>(*this, other, A {});
662-
}
700+
/**
701+
* Shorthand for xsimd::eq()
702+
*/
703+
template <class T, class A>
704+
inline batch_bool<T, A> eq(batch<T, A> const& self, batch<T, A> const& other) noexcept
705+
{
706+
detail::static_check_supported_config<T, A>();
707+
return kernel::eq<A>(self, other, A {});
708+
}
663709

664-
/**
665-
* Shorthand for xsimd::neq()
666-
*/
667-
template <class T, class A>
668-
inline batch_bool<T, A> batch<T, A>::operator!=(batch<T, A> const& other) const noexcept
669-
{
670-
detail::static_check_supported_config<T, A>();
671-
return kernel::neq<A>(*this, other, A {});
672-
}
710+
/**
711+
* Shorthand for xsimd::neq()
712+
*/
713+
template <class T, class A>
714+
inline batch_bool<T, A> neq(batch<T, A> const& self, batch<T, A> const& other) noexcept
715+
{
716+
detail::static_check_supported_config<T, A>();
717+
return kernel::neq<A>(self, other, A {});
718+
}
673719

674-
/**
675-
* Shorthand for xsimd::ge()
676-
*/
677-
template <class T, class A>
678-
inline batch_bool<T, A> batch<T, A>::operator>=(batch<T, A> const& other) const noexcept
679-
{
680-
detail::static_check_supported_config<T, A>();
681-
return kernel::ge<A>(*this, other, A {});
682-
}
720+
/**
721+
* Shorthand for xsimd::ge()
722+
*/
723+
template <class T, class A>
724+
inline batch_bool<T, A> ge(batch<T, A> const& self, batch<T, A> const& other) noexcept
725+
{
726+
detail::static_check_supported_config<T, A>();
727+
return kernel::ge<A>(self, other, A {});
728+
}
683729

684-
/**
685-
* Shorthand for xsimd::le()
686-
*/
687-
template <class T, class A>
688-
inline batch_bool<T, A> batch<T, A>::operator<=(batch<T, A> const& other) const noexcept
689-
{
690-
detail::static_check_supported_config<T, A>();
691-
return kernel::le<A>(*this, other, A {});
692-
}
730+
/**
731+
* Shorthand for xsimd::le()
732+
*/
733+
template <class T, class A>
734+
inline batch_bool<T, A> le(batch<T, A> const& self, batch<T, A> const& other) noexcept
735+
{
736+
detail::static_check_supported_config<T, A>();
737+
return kernel::le<A>(self, other, A {});
738+
}
693739

694-
/**
695-
* Shorthand for xsimd::gt()
696-
*/
697-
template <class T, class A>
698-
inline batch_bool<T, A> batch<T, A>::operator>(batch<T, A> const& other) const noexcept
699-
{
700-
detail::static_check_supported_config<T, A>();
701-
return kernel::gt<A>(*this, other, A {});
702-
}
740+
/**
741+
* Shorthand for xsimd::gt()
742+
*/
743+
template <class T, class A>
744+
inline batch_bool<T, A> gt(batch<T, A> const& self, batch<T, A> const& other) noexcept
745+
{
746+
detail::static_check_supported_config<T, A>();
747+
return kernel::gt<A>(self, other, A {});
748+
}
703749

704-
/**
705-
* Shorthand for xsimd::lt()
706-
*/
707-
template <class T, class A>
708-
inline batch_bool<T, A> batch<T, A>::operator<(batch<T, A> const& other) const noexcept
709-
{
710-
detail::static_check_supported_config<T, A>();
711-
return kernel::lt<A>(*this, other, A {});
750+
/**
751+
* Shorthand for xsimd::lt()
752+
*/
753+
template <class T, class A>
754+
inline batch_bool<T, A> lt(batch<T, A> const& self, batch<T, A> const& other) noexcept
755+
{
756+
detail::static_check_supported_config<T, A>();
757+
return kernel::lt<A>(self, other, A {});
758+
}
712759
}
713760

714761
/**************************

test/test_batch.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,10 @@ struct batch_test
437437
auto res = batch_lhs() < scalar;
438438
INFO("batch < scalar");
439439
CHECK_BATCH_EQ(res, expected);
440+
441+
auto res_neg = batch_lhs() >= scalar;
442+
INFO("batch >= scalar");
443+
CHECK_BATCH_EQ(!res_neg, expected);
440444
}
441445

442446
// batch <= batch
@@ -458,6 +462,10 @@ struct batch_test
458462
auto res = batch_lhs() <= scalar;
459463
INFO("batch <= scalar");
460464
CHECK_BATCH_EQ(res, expected);
465+
466+
auto res_neg = batch_lhs() > scalar;
467+
INFO("batch > scalar");
468+
CHECK_BATCH_EQ(!res_neg, expected);
461469
}
462470

463471
// batch > batch
@@ -479,6 +487,10 @@ struct batch_test
479487
auto res = batch_lhs() > scalar;
480488
INFO("batch > scalar");
481489
CHECK_BATCH_EQ(res, expected);
490+
491+
auto res_neg = batch_lhs() <= scalar;
492+
INFO("batch <= scalar");
493+
CHECK_BATCH_EQ(!res_neg, expected);
482494
}
483495
// batch >= batch
484496
{
@@ -499,6 +511,10 @@ struct batch_test
499511
auto res = batch_lhs() >= scalar;
500512
INFO("batch >= scalar");
501513
CHECK_BATCH_EQ(res, expected);
514+
515+
auto res_neg = batch_lhs() < scalar;
516+
INFO("batch < scalar");
517+
CHECK_BATCH_EQ(!res_neg, expected);
502518
}
503519
}
504520

0 commit comments

Comments
 (0)