Skip to content

Commit 056a1c9

Browse files
authored
Merge pull request #33 from AdaWorldAPI/claude/continue-session-0mAVa
feat(hpc): complete Pumpkin shopping list — 5 gap fills + JIT noise codegen
2 parents a0718f7 + f68c382 commit 056a1c9

7 files changed

Lines changed: 1289 additions & 2 deletions

File tree

src/hpc/aabb.rs

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,15 @@ unsafe fn aabb_intersect_batch_sse41(query: &Aabb, candidates: &[Aabb]) -> Vec<b
292292
/// assert!(!hits[1]);
293293
/// ```
294294
pub fn ray_aabb_slab_test_batch(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>) {
295+
#[cfg(target_arch = "x86_64")]
296+
{
297+
if is_x86_feature_detected!("avx512f") && aabbs.len() >= 16 {
298+
// SAFETY: avx512f detected, enough AABBs for batch processing.
299+
unsafe {
300+
return ray_aabb_slab_test_avx512(ray, aabbs);
301+
}
302+
}
303+
}
295304
ray_aabb_slab_test_scalar(ray, aabbs)
296305
}
297306

@@ -320,6 +329,128 @@ fn ray_aabb_slab_test_scalar(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>)
320329
(hits, t_values)
321330
}
322331

332+
/// AVX-512 batch ray-AABB slab test: processes 16 AABBs per iteration.
333+
///
334+
/// Broadcasts ray origin and inv_dir per axis, gathers candidate min/max
335+
/// coords into SoA arrays, computes slab intervals with `_mm512_min_ps` /
336+
/// `_mm512_max_ps`, and combines masks with `_mm512_cmp_ps_mask`.
337+
///
338+
/// # Safety
339+
/// Caller must ensure AVX-512F is available.
340+
#[cfg(target_arch = "x86_64")]
341+
#[target_feature(enable = "avx512f")]
342+
unsafe fn ray_aabb_slab_test_avx512(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>) {
343+
use core::arch::x86_64::*;
344+
345+
let mut hits = Vec::with_capacity(aabbs.len());
346+
let mut t_values = Vec::with_capacity(aabbs.len());
347+
348+
// Broadcast ray origin and inv_dir per axis
349+
let orig_x = _mm512_set1_ps(ray.origin[0]);
350+
let orig_y = _mm512_set1_ps(ray.origin[1]);
351+
let orig_z = _mm512_set1_ps(ray.origin[2]);
352+
let inv_x = _mm512_set1_ps(ray.inv_dir[0]);
353+
let inv_y = _mm512_set1_ps(ray.inv_dir[1]);
354+
let inv_z = _mm512_set1_ps(ray.inv_dir[2]);
355+
let zero = _mm512_set1_ps(0.0);
356+
357+
// Process 16 AABBs at a time
358+
let chunks = aabbs.len() / 16;
359+
for c in 0..chunks {
360+
let base = c * 16;
361+
362+
// Gather min/max coords for 16 AABBs into SoA arrays
363+
let mut a_min_x = [0.0f32; 16];
364+
let mut a_max_x = [0.0f32; 16];
365+
let mut a_min_y = [0.0f32; 16];
366+
let mut a_max_y = [0.0f32; 16];
367+
let mut a_min_z = [0.0f32; 16];
368+
let mut a_max_z = [0.0f32; 16];
369+
370+
for i in 0..16 {
371+
let aabb = &aabbs[base + i];
372+
a_min_x[i] = aabb.min[0];
373+
a_max_x[i] = aabb.max[0];
374+
a_min_y[i] = aabb.min[1];
375+
a_max_y[i] = aabb.max[1];
376+
a_min_z[i] = aabb.min[2];
377+
a_max_z[i] = aabb.max[2];
378+
}
379+
380+
// SAFETY: arrays are 16-element, avx512f checked by caller.
381+
let v_min_x = _mm512_loadu_ps(a_min_x.as_ptr());
382+
let v_max_x = _mm512_loadu_ps(a_max_x.as_ptr());
383+
let v_min_y = _mm512_loadu_ps(a_min_y.as_ptr());
384+
let v_max_y = _mm512_loadu_ps(a_max_y.as_ptr());
385+
let v_min_z = _mm512_loadu_ps(a_min_z.as_ptr());
386+
let v_max_z = _mm512_loadu_ps(a_max_z.as_ptr());
387+
388+
// X axis: t1 = (min - origin) * inv_dir, t2 = (max - origin) * inv_dir
389+
let t1_x = _mm512_mul_ps(_mm512_sub_ps(v_min_x, orig_x), inv_x);
390+
let t2_x = _mm512_mul_ps(_mm512_sub_ps(v_max_x, orig_x), inv_x);
391+
let t_near_x = _mm512_min_ps(t1_x, t2_x);
392+
let t_far_x = _mm512_max_ps(t1_x, t2_x);
393+
394+
// Y axis
395+
let t1_y = _mm512_mul_ps(_mm512_sub_ps(v_min_y, orig_y), inv_y);
396+
let t2_y = _mm512_mul_ps(_mm512_sub_ps(v_max_y, orig_y), inv_y);
397+
let t_near_y = _mm512_min_ps(t1_y, t2_y);
398+
let t_far_y = _mm512_max_ps(t1_y, t2_y);
399+
400+
// Z axis
401+
let t1_z = _mm512_mul_ps(_mm512_sub_ps(v_min_z, orig_z), inv_z);
402+
let t2_z = _mm512_mul_ps(_mm512_sub_ps(v_max_z, orig_z), inv_z);
403+
let t_near_z = _mm512_min_ps(t1_z, t2_z);
404+
let t_far_z = _mm512_max_ps(t1_z, t2_z);
405+
406+
// t_enter = max(t_near_x, t_near_y, t_near_z)
407+
let t_enter = _mm512_max_ps(_mm512_max_ps(t_near_x, t_near_y), t_near_z);
408+
// t_exit = min(t_far_x, t_far_y, t_far_z)
409+
let t_exit = _mm512_min_ps(_mm512_min_ps(t_far_x, t_far_y), t_far_z);
410+
411+
// hit = t_enter <= t_exit AND t_exit >= 0
412+
// _CMP_LE_OQ = 18, _CMP_GE_OQ = 29 (ordered, quiet)
413+
let m_le = _mm512_cmp_ps_mask::<{ _CMP_LE_OQ }>(t_enter, t_exit);
414+
let m_ge = _mm512_cmp_ps_mask::<{ _CMP_GE_OQ }>(t_exit, zero);
415+
let hit_mask = m_le & m_ge;
416+
417+
// Clamp t_enter to 0 for origins inside box
418+
let t_enter_clamped = _mm512_max_ps(t_enter, zero);
419+
420+
// SAFETY: 16-element array matches __m512 lane count.
421+
let mut t_arr = [0.0f32; 16];
422+
_mm512_storeu_ps(t_arr.as_mut_ptr(), t_enter_clamped);
423+
424+
for i in 0..16 {
425+
let hit = (hit_mask >> i) & 1 != 0;
426+
hits.push(hit);
427+
t_values.push(if hit { t_arr[i] } else { f32::MAX });
428+
}
429+
}
430+
431+
// Scalar tail for remainder
432+
for i in (chunks * 16)..aabbs.len() {
433+
let aabb = &aabbs[i];
434+
let mut t_enter = f32::NEG_INFINITY;
435+
let mut t_exit = f32::INFINITY;
436+
437+
for axis in 0..3 {
438+
let t1 = (aabb.min[axis] - ray.origin[axis]) * ray.inv_dir[axis];
439+
let t2 = (aabb.max[axis] - ray.origin[axis]) * ray.inv_dir[axis];
440+
let t_near = t1.min(t2);
441+
let t_far = t1.max(t2);
442+
t_enter = t_enter.max(t_near);
443+
t_exit = t_exit.min(t_far);
444+
}
445+
446+
let hit = t_enter <= t_exit && t_exit >= 0.0;
447+
hits.push(hit);
448+
t_values.push(if hit { t_enter.max(0.0) } else { f32::MAX });
449+
}
450+
451+
(hits, t_values)
452+
}
453+
323454
/// Expand all AABBs in-place by `(dx, dy, dz)` in both directions per axis.
324455
pub fn aabb_expand_batch(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) {
325456
#[cfg(target_arch = "x86_64")]
@@ -722,4 +853,28 @@ mod tests {
722853
assert!(approx_eq(ray.inv_dir[0], 0.5));
723854
assert!(ray.inv_dir[1].is_infinite());
724855
}
856+
857+
// ---------- AVX-512 ray-AABB parity ----------
858+
859+
#[test]
860+
fn test_ray_aabb_avx512_parity() {
861+
// 100 AABBs to exercise AVX-512 + tail
862+
let ray = Ray::new([0.0, 0.5, 0.5], [1.0, 0.0, 0.0]);
863+
let aabbs: Vec<Aabb> = (0..100)
864+
.map(|i| {
865+
let f = i as f32;
866+
Aabb::new([f, 0.0, 0.0], [f + 1.0, 1.0, 1.0])
867+
})
868+
.collect();
869+
let (hits_batch, ts_batch) = ray_aabb_slab_test_batch(&ray, &aabbs);
870+
let (hits_scalar, ts_scalar) = ray_aabb_slab_test_scalar(&ray, &aabbs);
871+
assert_eq!(hits_batch, hits_scalar, "ray AVX-512 hit parity");
872+
for i in 0..100 {
873+
assert!(
874+
approx_eq(ts_batch[i], ts_scalar[i]),
875+
"ray AVX-512 t parity at {i}: {} vs {}",
876+
ts_batch[i], ts_scalar[i]
877+
);
878+
}
879+
}
725880
}

0 commit comments

Comments
 (0)