@@ -292,6 +292,15 @@ unsafe fn aabb_intersect_batch_sse41(query: &Aabb, candidates: &[Aabb]) -> Vec<b
292292/// assert!(!hits[1]);
293293/// ```
294294pub fn ray_aabb_slab_test_batch ( ray : & Ray , aabbs : & [ Aabb ] ) -> ( Vec < bool > , Vec < f32 > ) {
295+ #[ cfg( target_arch = "x86_64" ) ]
296+ {
297+ if is_x86_feature_detected ! ( "avx512f" ) && aabbs. len ( ) >= 16 {
298+ // SAFETY: avx512f detected, enough AABBs for batch processing.
299+ unsafe {
300+ return ray_aabb_slab_test_avx512 ( ray, aabbs) ;
301+ }
302+ }
303+ }
295304 ray_aabb_slab_test_scalar ( ray, aabbs)
296305}
297306
@@ -320,6 +329,128 @@ fn ray_aabb_slab_test_scalar(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>)
320329 ( hits, t_values)
321330}
322331
332+ /// AVX-512 batch ray-AABB slab test: processes 16 AABBs per iteration.
333+ ///
334+ /// Broadcasts ray origin and inv_dir per axis, gathers candidate min/max
335+ /// coords into SoA arrays, computes slab intervals with `_mm512_min_ps` /
336+ /// `_mm512_max_ps`, and combines masks with `_mm512_cmp_ps_mask`.
337+ ///
338+ /// # Safety
339+ /// Caller must ensure AVX-512F is available.
340+ #[ cfg( target_arch = "x86_64" ) ]
341+ #[ target_feature( enable = "avx512f" ) ]
342+ unsafe fn ray_aabb_slab_test_avx512 ( ray : & Ray , aabbs : & [ Aabb ] ) -> ( Vec < bool > , Vec < f32 > ) {
343+ use core:: arch:: x86_64:: * ;
344+
345+ let mut hits = Vec :: with_capacity ( aabbs. len ( ) ) ;
346+ let mut t_values = Vec :: with_capacity ( aabbs. len ( ) ) ;
347+
348+ // Broadcast ray origin and inv_dir per axis
349+ let orig_x = _mm512_set1_ps ( ray. origin [ 0 ] ) ;
350+ let orig_y = _mm512_set1_ps ( ray. origin [ 1 ] ) ;
351+ let orig_z = _mm512_set1_ps ( ray. origin [ 2 ] ) ;
352+ let inv_x = _mm512_set1_ps ( ray. inv_dir [ 0 ] ) ;
353+ let inv_y = _mm512_set1_ps ( ray. inv_dir [ 1 ] ) ;
354+ let inv_z = _mm512_set1_ps ( ray. inv_dir [ 2 ] ) ;
355+ let zero = _mm512_set1_ps ( 0.0 ) ;
356+
357+ // Process 16 AABBs at a time
358+ let chunks = aabbs. len ( ) / 16 ;
359+ for c in 0 ..chunks {
360+ let base = c * 16 ;
361+
362+ // Gather min/max coords for 16 AABBs into SoA arrays
363+ let mut a_min_x = [ 0.0f32 ; 16 ] ;
364+ let mut a_max_x = [ 0.0f32 ; 16 ] ;
365+ let mut a_min_y = [ 0.0f32 ; 16 ] ;
366+ let mut a_max_y = [ 0.0f32 ; 16 ] ;
367+ let mut a_min_z = [ 0.0f32 ; 16 ] ;
368+ let mut a_max_z = [ 0.0f32 ; 16 ] ;
369+
370+ for i in 0 ..16 {
371+ let aabb = & aabbs[ base + i] ;
372+ a_min_x[ i] = aabb. min [ 0 ] ;
373+ a_max_x[ i] = aabb. max [ 0 ] ;
374+ a_min_y[ i] = aabb. min [ 1 ] ;
375+ a_max_y[ i] = aabb. max [ 1 ] ;
376+ a_min_z[ i] = aabb. min [ 2 ] ;
377+ a_max_z[ i] = aabb. max [ 2 ] ;
378+ }
379+
380+ // SAFETY: arrays are 16-element, avx512f checked by caller.
381+ let v_min_x = _mm512_loadu_ps ( a_min_x. as_ptr ( ) ) ;
382+ let v_max_x = _mm512_loadu_ps ( a_max_x. as_ptr ( ) ) ;
383+ let v_min_y = _mm512_loadu_ps ( a_min_y. as_ptr ( ) ) ;
384+ let v_max_y = _mm512_loadu_ps ( a_max_y. as_ptr ( ) ) ;
385+ let v_min_z = _mm512_loadu_ps ( a_min_z. as_ptr ( ) ) ;
386+ let v_max_z = _mm512_loadu_ps ( a_max_z. as_ptr ( ) ) ;
387+
388+ // X axis: t1 = (min - origin) * inv_dir, t2 = (max - origin) * inv_dir
389+ let t1_x = _mm512_mul_ps ( _mm512_sub_ps ( v_min_x, orig_x) , inv_x) ;
390+ let t2_x = _mm512_mul_ps ( _mm512_sub_ps ( v_max_x, orig_x) , inv_x) ;
391+ let t_near_x = _mm512_min_ps ( t1_x, t2_x) ;
392+ let t_far_x = _mm512_max_ps ( t1_x, t2_x) ;
393+
394+ // Y axis
395+ let t1_y = _mm512_mul_ps ( _mm512_sub_ps ( v_min_y, orig_y) , inv_y) ;
396+ let t2_y = _mm512_mul_ps ( _mm512_sub_ps ( v_max_y, orig_y) , inv_y) ;
397+ let t_near_y = _mm512_min_ps ( t1_y, t2_y) ;
398+ let t_far_y = _mm512_max_ps ( t1_y, t2_y) ;
399+
400+ // Z axis
401+ let t1_z = _mm512_mul_ps ( _mm512_sub_ps ( v_min_z, orig_z) , inv_z) ;
402+ let t2_z = _mm512_mul_ps ( _mm512_sub_ps ( v_max_z, orig_z) , inv_z) ;
403+ let t_near_z = _mm512_min_ps ( t1_z, t2_z) ;
404+ let t_far_z = _mm512_max_ps ( t1_z, t2_z) ;
405+
406+ // t_enter = max(t_near_x, t_near_y, t_near_z)
407+ let t_enter = _mm512_max_ps ( _mm512_max_ps ( t_near_x, t_near_y) , t_near_z) ;
408+ // t_exit = min(t_far_x, t_far_y, t_far_z)
409+ let t_exit = _mm512_min_ps ( _mm512_min_ps ( t_far_x, t_far_y) , t_far_z) ;
410+
411+ // hit = t_enter <= t_exit AND t_exit >= 0
412+ // _CMP_LE_OQ = 18, _CMP_GE_OQ = 29 (ordered, quiet)
413+ let m_le = _mm512_cmp_ps_mask :: < { _CMP_LE_OQ } > ( t_enter, t_exit) ;
414+ let m_ge = _mm512_cmp_ps_mask :: < { _CMP_GE_OQ } > ( t_exit, zero) ;
415+ let hit_mask = m_le & m_ge;
416+
417+ // Clamp t_enter to 0 for origins inside box
418+ let t_enter_clamped = _mm512_max_ps ( t_enter, zero) ;
419+
420+ // SAFETY: 16-element array matches __m512 lane count.
421+ let mut t_arr = [ 0.0f32 ; 16 ] ;
422+ _mm512_storeu_ps ( t_arr. as_mut_ptr ( ) , t_enter_clamped) ;
423+
424+ for i in 0 ..16 {
425+ let hit = ( hit_mask >> i) & 1 != 0 ;
426+ hits. push ( hit) ;
427+ t_values. push ( if hit { t_arr[ i] } else { f32:: MAX } ) ;
428+ }
429+ }
430+
431+ // Scalar tail for remainder
432+ for i in ( chunks * 16 ) ..aabbs. len ( ) {
433+ let aabb = & aabbs[ i] ;
434+ let mut t_enter = f32:: NEG_INFINITY ;
435+ let mut t_exit = f32:: INFINITY ;
436+
437+ for axis in 0 ..3 {
438+ let t1 = ( aabb. min [ axis] - ray. origin [ axis] ) * ray. inv_dir [ axis] ;
439+ let t2 = ( aabb. max [ axis] - ray. origin [ axis] ) * ray. inv_dir [ axis] ;
440+ let t_near = t1. min ( t2) ;
441+ let t_far = t1. max ( t2) ;
442+ t_enter = t_enter. max ( t_near) ;
443+ t_exit = t_exit. min ( t_far) ;
444+ }
445+
446+ let hit = t_enter <= t_exit && t_exit >= 0.0 ;
447+ hits. push ( hit) ;
448+ t_values. push ( if hit { t_enter. max ( 0.0 ) } else { f32:: MAX } ) ;
449+ }
450+
451+ ( hits, t_values)
452+ }
453+
323454/// Expand all AABBs in-place by `(dx, dy, dz)` in both directions per axis.
324455pub fn aabb_expand_batch ( aabbs : & mut [ Aabb ] , dx : f32 , dy : f32 , dz : f32 ) {
325456 #[ cfg( target_arch = "x86_64" ) ]
@@ -722,4 +853,28 @@ mod tests {
722853 assert ! ( approx_eq( ray. inv_dir[ 0 ] , 0.5 ) ) ;
723854 assert ! ( ray. inv_dir[ 1 ] . is_infinite( ) ) ;
724855 }
856+
857+ // ---------- AVX-512 ray-AABB parity ----------
858+
#[test]
fn test_ray_aabb_avx512_parity() {
    // 100 unit-wide boxes laid end to end along +X: six full 16-wide chunks
    // plus a 4-box remainder, so both the vector loop and the scalar tail of
    // the batch path run whenever the batch entry point selects AVX-512.
    let ray = Ray::new([0.0, 0.5, 0.5], [1.0, 0.0, 0.0]);
    let mut aabbs: Vec<Aabb> = Vec::with_capacity(100);
    for i in 0..100 {
        let x0 = i as f32;
        aabbs.push(Aabb::new([x0, 0.0, 0.0], [x0 + 1.0, 1.0, 1.0]));
    }

    let (hits_batch, ts_batch) = ray_aabb_slab_test_batch(&ray, &aabbs);
    let (hits_scalar, ts_scalar) = ray_aabb_slab_test_scalar(&ray, &aabbs);

    // Hit flags must agree exactly; distances only up to `approx_eq`.
    assert_eq!(hits_batch, hits_scalar, "ray AVX-512 hit parity");
    for i in 0..aabbs.len() {
        let (t_batch, t_scalar) = (ts_batch[i], ts_scalar[i]);
        assert!(
            approx_eq(t_batch, t_scalar),
            "ray AVX-512 t parity at {i}: {} vs {}",
            t_batch, t_scalar
        );
    }
}
725880}
0 commit comments