@@ -3362,16 +3362,51 @@ namespace xsimd
33623362 /* ********
33633363 * count *
33643364 *********/
3365- template <class A , class T >
3365+
3366+ // NOTE: Extracting a u32 for the return value saves two instructions on 32-bit ARM:
3367+ // <https://godbolt.org/z/PYn4na8sY>.
3368+
3369+ template <class A , class T , detail::enable_sized_t <T, 1 > = 0 >
3370+ XSIMD_INLINE size_t count (batch_bool<T, A> const & self, requires_arch<neon>) noexcept
3371+ {
3372+ uint8x16_t msbs = vshrq_n_u8 (self, 7 );
3373+ uint64x2_t psum = vpaddlq_u32 (vpaddlq_u16 (vpaddlq_u8 (msbs)));
3374+ uint64x1_t total = vadd_u64 (vget_low_u64 (psum), vget_high_u64 (psum));
3375+
3376+ assert (vget_lane_u64 (total, 0 ) <= std::numeric_limits<uint32_t >::max ());
3377+ return vget_lane_u32 (vreinterpret_u32_u64 (total), 0 );
3378+ }
3379+
3380+ template <class A , class T , detail::enable_sized_t <T, 2 > = 0 >
33663381 XSIMD_INLINE size_t count (batch_bool<T, A> const & self, requires_arch<neon>) noexcept
33673382 {
3368- uint8x16_t popcnts = vcntq_u8 (bitwise_cast< uint8_t >( bitwise_cast ( self)) );
3369- uint64x2_t psum = vpaddlq_u32 (vpaddlq_u16 (vpaddlq_u8 (popcnts) ));
3383+ uint16x8_t msbs = vshrq_n_u16 ( self, 15 );
3384+ uint64x2_t psum = vpaddlq_u32 (vpaddlq_u16 (msbs ));
33703385 uint64x1_t total = vadd_u64 (vget_low_u64 (psum), vget_high_u64 (psum));
33713386
3372- // NOTE: Extracting a u32 saves two instructions on 32-bit ARM: <https://godbolt.org/z/PYn4na8sY>.
3373- assert (vget_lane_u64 (total) <= constants::maxvalue<uint32_t >());
3374- return vget_lane_u32 (vreinterpret_u32_u64 (total), 0 ) / (sizeof (T) * 8 );
3387+ assert (vget_lane_u64 (total, 0 ) <= std::numeric_limits<uint32_t >::max ());
3388+ return vget_lane_u32 (vreinterpret_u32_u64 (total), 0 );
3389+ }
3390+
3391+ template <class A , class T , detail::enable_sized_t <T, 4 > = 0 >
3392+ XSIMD_INLINE size_t count (batch_bool<T, A> const & self, requires_arch<neon>) noexcept
3393+ {
3394+ uint32x4_t msbs = vshrq_n_u32 (self, 31 );
3395+ uint64x2_t psum = vpaddlq_u32 (msbs);
3396+ uint64x1_t total = vadd_u64 (vget_low_u64 (psum), vget_high_u64 (psum));
3397+
3398+ assert (vget_lane_u64 (total, 0 ) <= std::numeric_limits<uint32_t >::max ());
3399+ return vget_lane_u32 (vreinterpret_u32_u64 (total), 0 );
3400+ }
3401+
3402+ template <class A , class T , detail::enable_sized_t <T, 8 > = 0 >
3403+ XSIMD_INLINE size_t count (batch_bool<T, A> const & self, requires_arch<neon>) noexcept
3404+ {
3405+ uint64x2_t msbs = vshrq_n_u64 (self, 63 );
3406+ uint64x1_t total = vadd_u64 (vget_low_u64 (msbs), vget_high_u64 (msbs));
3407+
3408+ assert (vget_lane_u64 (total, 0 ) <= std::numeric_limits<uint32_t >::max ());
3409+ return vget_lane_u32 (vreinterpret_u32_u64 (total), 0 );
33753410 }
33763411
33773412#define WRAP_MASK_OP (OP ) \
0 commit comments