Skip to content

Commit 5f7cb4a

Browse files
breaking: refactor trait api of capt to make it easier to use
1 parent 514fe71 commit 5f7cb4a

2 files changed

Lines changed: 48 additions & 34 deletions

File tree

bench/src/kdt.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
use std::{
22
mem::size_of,
3-
simd::{num::SimdInt, Simd, SupportedLaneCount},
4-
};
5-
6-
use capt::{Aabb, Axis, AxisSimd};
7-
8-
use std::simd::{
9-
cmp::{SimdPartialEq, SimdPartialOrd},
10-
ptr::SimdConstPtr,
11-
LaneCount, Mask,
3+
simd::{
4+
cmp::SimdPartialOrd, num::SimdInt, ptr::SimdConstPtr, LaneCount, Mask, Simd,
5+
SupportedLaneCount,
6+
},
127
};
138

149
use crate::{distsq, forward_pass, median_partition};
10+
use capt::{Aabb, AxisSimd, AxisSimdElement};
1511

1612
#[derive(Clone, Debug, PartialEq)]
1713
/// A power-of-two KD-tree.
@@ -219,17 +215,16 @@ fn forward_pass_simd<A, const K: usize, const L: usize>(
219215
centers: &[Simd<A, L>; K],
220216
) -> Simd<usize, L>
221217
where
222-
Simd<A, L>: SimdPartialOrd,
223-
Mask<isize, L>: From<<Simd<A, L> as SimdPartialEq>::Mask>,
224-
A: Axis + AxisSimd<<Simd<A, L> as SimdPartialEq>::Mask>,
218+
Simd<A, L>: AxisSimd<L>,
219+
A: AxisSimdElement,
225220
LaneCount<L>: SupportedLaneCount,
226221
{
227222
let mut i: Simd<usize, L> = Simd::splat(0);
228223
let mut k = 0;
229224
for _ in 0..tests.len().trailing_ones() {
230225
let test_ptrs = Simd::splat(tests.as_ptr()).wrapping_add(i);
231226
let relevant_tests = unsafe { Simd::gather_ptr(test_ptrs) };
232-
let cmp: Mask<isize, L> = centers[k].simd_ge(relevant_tests).into();
227+
let cmp: Mask<isize, L> = Simd::<A, L>::cast_mask(centers[k].simd_ge(relevant_tests));
233228

234229
let one = Simd::splat(1);
235230
i = (i << one) + one + (cmp.to_int().cast() & one);

capt/src/lib.rs

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ use core::{
8383

8484
#[cfg(feature = "simd")]
8585
use core::{
86-
ops::{AddAssign, Mul},
86+
ops::{AddAssign, Mul, SubAssign},
8787
ptr,
8888
simd::{
89+
LaneCount, Mask, Simd, SimdElement, SupportedLaneCount,
8990
cmp::{SimdPartialEq, SimdPartialOrd},
9091
ptr::SimdConstPtr,
91-
LaneCount, Mask, Simd, SimdElement, SupportedLaneCount,
9292
},
9393
};
9494

@@ -192,15 +192,30 @@ pub trait Axis: PartialOrd + Copy + Sub<Output = Self> + Add<Output = Self> {
192192
fn square(self) -> Self;
193193
}
194194

195+
#[cfg(feature = "simd")]
196+
/// A trait used for SIMD elements.
197+
pub trait AxisSimdElement: SimdElement + Default + Axis {}
198+
195199
#[cfg(feature = "simd")]
196200
/// A trait used for masks over SIMD vectors, used for parallel querying on [`Capt`]s.
197201
///
198202
/// The interface for this trait should be considered unstable since the standard SIMD API may
199203
/// change with Rust versions.
200-
pub trait AxisSimd<M>: SimdElement + Default {
201-
#[must_use]
202-
/// Determine whether any element of this mask is set to `true`.
203-
fn any(mask: M) -> bool;
204+
pub trait AxisSimd<const L: usize>:
205+
Sized
206+
+ SimdPartialOrd
207+
+ Add<Output = Self>
208+
+ AddAssign
209+
+ Sub<Output = Self>
210+
+ SubAssign
211+
+ Mul<Output = Self>
212+
where
213+
LaneCount<L>: SupportedLaneCount,
214+
{
215+
/// Cast a mask for a SIMD vector into a mask of `isize`s.
216+
fn cast_mask(mask: <Self as SimdPartialEq>::Mask) -> Mask<isize, L>;
217+
/// Determine whether a mask contains any true elements.
218+
fn mask_any(mask: <Self as SimdPartialEq>::Mask) -> bool;
204219
}
205220

206221
/// An index type used for lookups into and out of arrays.
@@ -251,12 +266,18 @@ macro_rules! impl_axis {
251266
}
252267

253268
#[cfg(feature = "simd")]
254-
impl<const L: usize> AxisSimd<Mask<$tm, L>> for $t
269+
impl AxisSimdElement for $t {}
270+
271+
#[cfg(feature = "simd")]
272+
impl<const L: usize> AxisSimd<L> for Simd<$t, L>
255273
where
256274
LaneCount<L>: SupportedLaneCount,
257275
{
258-
fn any(mask: Mask<$tm, L>) -> bool {
259-
Mask::<$tm, L>::any(mask)
276+
fn cast_mask(mask: <Self as SimdPartialEq>::Mask) -> Mask<isize, L> {
277+
mask.into()
278+
}
279+
fn mask_any(mask: <Self as SimdPartialEq>::Mask) -> bool {
280+
mask.any()
260281
}
261282
}
262283
};
@@ -308,17 +329,17 @@ fn forward_pass_simd<A, const K: usize, const L: usize>(
308329
centers: &[Simd<A, L>; K],
309330
) -> Simd<isize, L>
310331
where
311-
Simd<A, L>: SimdPartialOrd,
312-
Mask<isize, L>: From<<Simd<A, L> as SimdPartialEq>::Mask>,
313-
A: Axis + AxisSimd<<Simd<A, L> as SimdPartialEq>::Mask>,
332+
Simd<A, L>: AxisSimd<L>,
333+
A: AxisSimdElement,
314334
LaneCount<L>: SupportedLaneCount,
315335
{
316336
let mut test_idxs: Simd<isize, L> = Simd::splat(0);
317337
let mut k = 0;
318338
for _ in 0..tests.len().trailing_ones() {
319339
let test_ptrs = Simd::splat(tests.as_ptr()).wrapping_offset(test_idxs);
320340
let relevant_tests: Simd<A, L> = unsafe { Simd::gather_ptr(test_ptrs) };
321-
let cmp_results: Mask<isize, L> = centers[k % K].simd_ge(relevant_tests).into();
341+
let cmp_results: Mask<isize, L> =
342+
Simd::<A, L>::cast_mask(centers[k % K].simd_ge(relevant_tests));
322343

323344
let one = Simd::splat(1);
324345
test_idxs = (test_idxs << one) + one + (cmp_results.to_int() & Simd::splat(1));
@@ -884,10 +905,8 @@ where
884905
pub fn collides_simd(&self, centers: &[Simd<A, L>; K], mut radii: Simd<A, L>) -> bool
885906
where
886907
LaneCount<L>: SupportedLaneCount,
887-
Simd<A, L>:
888-
SimdPartialOrd + Sub<Output = Simd<A, L>> + Mul<Output = Simd<A, L>> + AddAssign,
889-
Mask<isize, L>: From<<Simd<A, L> as SimdPartialEq>::Mask>,
890-
A: Axis + AxisSimd<<Simd<A, L> as SimdPartialEq>::Mask>,
908+
Simd<A, L>: AxisSimd<L>,
909+
A: AxisSimdElement,
891910
{
892911
radii += Simd::splat(self.r_point);
893912
let zs = forward_pass_simd(&self.tests, centers);
@@ -898,15 +917,15 @@ where
898917

899918
unsafe {
900919
for center in centers {
901-
inbounds &= Mask::<isize, L>::from(
920+
inbounds &= Simd::<A, L>::cast_mask(
902921
(Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY))
903922
- radii)
904923
.simd_le(*center),
905924
);
906925
aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1));
907926
}
908927
for center in centers {
909-
inbounds &= Mask::<isize, L>::from(
928+
inbounds &= Simd::<A, L>::cast_mask(
910929
Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY))
911930
.simd_ge(*center - radii),
912931
);
@@ -948,7 +967,7 @@ where
948967
let diff = vals - n_center[k];
949968
dists_sq += diff * diff;
950969
}
951-
A::any(dists_sq.simd_le(rs_sq))
970+
Simd::<A, L>::mask_any(dists_sq.simd_le(rs_sq))
952971
})
953972
})
954973
}
@@ -1046,7 +1065,7 @@ unsafe fn median_partition<A: Axis, const K: usize>(points: &mut [[A; K]], k: us
10461065

10471066
#[cfg(test)]
10481067
mod tests {
1049-
use rand::{rngs::SmallRng, Rng, SeedableRng};
1068+
use rand::{Rng, SeedableRng, rngs::SmallRng};
10501069

10511070
use super::*;
10521071

0 commit comments

Comments
 (0)