@@ -83,12 +83,12 @@ use core::{
8383
8484#[ cfg( feature = "simd" ) ]
8585use core:: {
86- ops:: { AddAssign , Mul } ,
86+ ops:: { AddAssign , Mul , SubAssign } ,
8787 ptr,
8888 simd:: {
89+ LaneCount , Mask , Simd , SimdElement , SupportedLaneCount ,
8990 cmp:: { SimdPartialEq , SimdPartialOrd } ,
9091 ptr:: SimdConstPtr ,
91- LaneCount , Mask , Simd , SimdElement , SupportedLaneCount ,
9292 } ,
9393} ;
9494
@@ -192,15 +192,30 @@ pub trait Axis: PartialOrd + Copy + Sub<Output = Self> + Add<Output = Self> {
192192 fn square ( self ) -> Self ;
193193}
194194
195+ #[ cfg( feature = "simd" ) ]
196+ /// A trait used for SIMD elements.
197+ pub trait AxisSimdElement : SimdElement + Default + Axis { }
198+
195199#[ cfg( feature = "simd" ) ]
196200/// A trait used for masks over SIMD vectors, used for parallel querying on [`Capt`]s.
197201///
198202/// The interface for this trait should be considered unstable since the standard SIMD API may
199203/// change with Rust versions.
200- pub trait AxisSimd < M > : SimdElement + Default {
201- #[ must_use]
202- /// Determine whether any element of this mask is set to `true`.
203- fn any ( mask : M ) -> bool ;
204+ pub trait AxisSimd < const L : usize > :
205+ Sized
206+ + SimdPartialOrd
207+ + Add < Output = Self >
208+ + AddAssign
209+ + Sub < Output = Self >
210+ + SubAssign
211+ + Mul < Output = Self >
212+ where
213+ LaneCount < L > : SupportedLaneCount ,
214+ {
215+ /// Cast a mask for a SIMD vector into a mask of `isize`s.
216+ fn cast_mask ( mask : <Self as SimdPartialEq >:: Mask ) -> Mask < isize , L > ;
217+ /// Determine whether a mask contains any true elements.
218+ fn mask_any ( mask : <Self as SimdPartialEq >:: Mask ) -> bool ;
204219}
205220
206221/// An index type used for lookups into and out of arrays.
@@ -251,12 +266,18 @@ macro_rules! impl_axis {
251266 }
252267
253268 #[ cfg( feature = "simd" ) ]
254- impl <const L : usize > AxisSimd <Mask <$tm, L >> for $t
269+ impl AxisSimdElement for $t { }
270+
271+ #[ cfg( feature = "simd" ) ]
272+ impl <const L : usize > AxisSimd <L > for Simd <$t, L >
255273 where
256274 LaneCount <L >: SupportedLaneCount ,
257275 {
258- fn any( mask: Mask <$tm, L >) -> bool {
259- Mask :: <$tm, L >:: any( mask)
276+ fn cast_mask( mask: <Self as SimdPartialEq >:: Mask ) -> Mask <isize , L > {
277+ mask. into( )
278+ }
279+ fn mask_any( mask: <Self as SimdPartialEq >:: Mask ) -> bool {
280+ mask. any( )
260281 }
261282 }
262283 } ;
@@ -308,17 +329,17 @@ fn forward_pass_simd<A, const K: usize, const L: usize>(
308329 centers : & [ Simd < A , L > ; K ] ,
309330) -> Simd < isize , L >
310331where
311- Simd < A , L > : SimdPartialOrd ,
312- Mask < isize , L > : From < <Simd < A , L > as SimdPartialEq >:: Mask > ,
313- A : Axis + AxisSimd < <Simd < A , L > as SimdPartialEq >:: Mask > ,
332+ Simd < A , L > : AxisSimd < L > ,
333+ A : AxisSimdElement ,
314334 LaneCount < L > : SupportedLaneCount ,
315335{
316336 let mut test_idxs: Simd < isize , L > = Simd :: splat ( 0 ) ;
317337 let mut k = 0 ;
318338 for _ in 0 ..tests. len ( ) . trailing_ones ( ) {
319339 let test_ptrs = Simd :: splat ( tests. as_ptr ( ) ) . wrapping_offset ( test_idxs) ;
320340 let relevant_tests: Simd < A , L > = unsafe { Simd :: gather_ptr ( test_ptrs) } ;
321- let cmp_results: Mask < isize , L > = centers[ k % K ] . simd_ge ( relevant_tests) . into ( ) ;
341+ let cmp_results: Mask < isize , L > =
342+ Simd :: < A , L > :: cast_mask ( centers[ k % K ] . simd_ge ( relevant_tests) ) ;
322343
323344 let one = Simd :: splat ( 1 ) ;
324345 test_idxs = ( test_idxs << one) + one + ( cmp_results. to_int ( ) & Simd :: splat ( 1 ) ) ;
@@ -884,10 +905,8 @@ where
884905 pub fn collides_simd ( & self , centers : & [ Simd < A , L > ; K ] , mut radii : Simd < A , L > ) -> bool
885906 where
886907 LaneCount < L > : SupportedLaneCount ,
887- Simd < A , L > :
888- SimdPartialOrd + Sub < Output = Simd < A , L > > + Mul < Output = Simd < A , L > > + AddAssign ,
889- Mask < isize , L > : From < <Simd < A , L > as SimdPartialEq >:: Mask > ,
890- A : Axis + AxisSimd < <Simd < A , L > as SimdPartialEq >:: Mask > ,
908+ Simd < A , L > : AxisSimd < L > ,
909+ A : AxisSimdElement ,
891910 {
892911 radii += Simd :: splat ( self . r_point ) ;
893912 let zs = forward_pass_simd ( & self . tests , centers) ;
@@ -898,15 +917,15 @@ where
898917
899918 unsafe {
900919 for center in centers {
901- inbounds &= Mask :: < isize , L > :: from (
920+ inbounds &= Simd :: < A , L > :: cast_mask (
902921 ( Simd :: gather_select_ptr ( aabb_ptrs, inbounds, Simd :: splat ( A :: NEG_INFINITY ) )
903922 - radii)
904923 . simd_le ( * center) ,
905924 ) ;
906925 aabb_ptrs = aabb_ptrs. wrapping_add ( Simd :: splat ( 1 ) ) ;
907926 }
908927 for center in centers {
909- inbounds &= Mask :: < isize , L > :: from (
928+ inbounds &= Simd :: < A , L > :: cast_mask (
910929 Simd :: gather_select_ptr ( aabb_ptrs, inbounds, Simd :: splat ( A :: NEG_INFINITY ) )
911930 . simd_ge ( * center - radii) ,
912931 ) ;
@@ -948,7 +967,7 @@ where
948967 let diff = vals - n_center[ k] ;
949968 dists_sq += diff * diff;
950969 }
951- A :: any ( dists_sq. simd_le ( rs_sq) )
970+ Simd :: < A , L > :: mask_any ( dists_sq. simd_le ( rs_sq) )
952971 } )
953972 } )
954973 }
@@ -1046,7 +1065,7 @@ unsafe fn median_partition<A: Axis, const K: usize>(points: &mut [[A; K]], k: us
10461065
10471066#[ cfg( test) ]
10481067mod tests {
1049- use rand:: { rngs :: SmallRng , Rng , SeedableRng } ;
1068+ use rand:: { Rng , SeedableRng , rngs :: SmallRng } ;
10501069
10511070 use super :: * ;
10521071
0 commit comments