Skip to content

Commit 4d28884

Browse files
committed
fix(simd): preserve NaN in simd_exp_f32 (codex review on PR #142)
The pre-clamp via simd_clamp silently destroyed NaN inputs. simd_clamp is implemented as max(lo).min(hi); _mm512_max_ps returns the SECOND operand when the first is NaN (per Intel SDM § MAXPS), so NaN got clamped to lo (-87.336) and exp(-87.336) ≈ 1.2e-38 — a tiny finite value pretending to be valid.

Fix: capture NaN lanes via x.simd_ne(x) (NaN ≠ itself per IEEE 754) BEFORE the clamp, then mask-select NaN back into those lanes after the polynomial. NaN propagates per-lane; finite lanes are unchanged.

Two regression tests:
- simd_exp_f32_propagates_nan — full-NaN vector returns full-NaN
- simd_exp_f32_propagates_nan_per_lane — mixed NaN/0.0 input; NaN lanes propagate, finite lanes compute exp(0)=1 unaffected

1788 passed (+2 from 1786).

Reported-by: codex review on PR #142.
1 parent e566c33 commit 4d28884

1 file changed

Lines changed: 52 additions & 3 deletions

File tree

src/simd.rs

Lines changed: 52 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1598,15 +1598,27 @@ pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) {
15981598
/// integer exponent `n` stays within the IEEE 754 f32 representable range.
15991599
/// Beyond the upper bound we'd hit `i32` overflow in `pow2n_from_int` and
16001600
/// silently return ~0.5 instead of +Inf (release) or panic (debug).
1601-
/// NaN passes through the polynomial as NaN (NaN comparisons in `simd_clamp`
1602-
/// take neither branch on standard implementations).
1601+
///
1602+
/// NaN handling: `simd_clamp` is `max(lo).min(hi)`, and `_mm512_max_ps` /
1603+
/// `_mm512_min_ps` return the SECOND operand when the first is NaN (per
1604+
/// Intel SDM § MAXPS/MINPS). That would silently clamp NaN inputs to `lo`
1605+
/// (-87.336) producing `exp(-87.336) ≈ 1.2e-38` — a finite tiny value
1606+
/// masquerading as valid output. Caught by codex review on PR #142.
1607+
///
1608+
/// Fix: capture NaN lanes via `x.simd_ne(x)` (NaN ≠ itself per IEEE 754)
1609+
/// before the clamp, then mask-select NaN back into those lanes after
1610+
/// the polynomial. NaN lanes propagate as NaN; finite lanes are unchanged.
16031611
#[inline(always)]
16041612
#[allow(dead_code)]
16051613
pub fn simd_exp_f32(x: F32x16) -> F32x16 {
16061614
let ln2 = F32x16::splat(core::f32::consts::LN_2);
16071615
let inv_ln2 = F32x16::splat(1.0 / core::f32::consts::LN_2);
16081616
let one = F32x16::splat(1.0);
16091617

1618+
// NaN-preservation mask: bit set wherever x is NaN. IEEE 754: NaN ≠ NaN.
1619+
// Captured BEFORE the clamp because simd_clamp destroys NaN lanes.
1620+
let nan_mask = x.simd_ne(x);
1621+
16101622
// Pre-clamp to the safe domain. Outside this band exp() is non-representable
16111623
// anyway (overflow → +Inf at ~88.7, underflow → +0 at ~-87.3) so the clamp
16121624
// is observable only at the saturation boundary.
@@ -1625,7 +1637,10 @@ pub fn simd_exp_f32(x: F32x16) -> F32x16 {
16251637
let poly = one + r * (one + r * (c2 + r * (c3 + r * (c4 + r * c5))));
16261638

16271639
// Reconstruct: exp(x) = 2^n * poly
1628-
poly * pow2n_from_int(n)
1640+
let result = poly * pow2n_from_int(n);
1641+
1642+
// Restore NaN in lanes where the input was NaN (clamp had destroyed them).
1643+
nan_mask.select(F32x16::splat(f32::NAN), result)
16291644
}
16301645

16311646
/// Compute 2^n where n is an integer stored as f32.
@@ -1842,6 +1857,40 @@ mod tests {
18421857
}
18431858
}
18441859

1860+
#[test]
1861+
fn simd_exp_f32_propagates_nan() {
1862+
// simd_clamp is max(lo).min(hi); _mm512_max_ps returns the SECOND
1863+
// operand on NaN, so without the nan_mask save/restore, NaN would
1864+
// be silently clamped to -87.336 → exp ≈ 1.2e-38 (a tiny finite
1865+
// value pretending to be valid). With the mask, NaN propagates.
1866+
// Per codex review on PR #142.
1867+
let nan = F32x16::splat(f32::NAN);
1868+
let result = simd_exp_f32(nan);
1869+
let arr = result.to_array();
1870+
for &v in &arr {
1871+
assert!(v.is_nan(), "exp(NaN) must propagate NaN, got {}", v);
1872+
}
1873+
}
1874+
1875+
#[test]
1876+
fn simd_exp_f32_propagates_nan_per_lane() {
1877+
// Mixed input: lanes 0,4,8,12 are NaN; rest are 0.0. Verify that
1878+
// NaN propagates only in those lanes; the others compute exp(0)=1.
1879+
let mut data = [0.0f32; 16];
1880+
for i in (0..16).step_by(4) {
1881+
data[i] = f32::NAN;
1882+
}
1883+
let result = simd_exp_f32(F32x16::from_array(data));
1884+
let arr = result.to_array();
1885+
for (i, &v) in arr.iter().enumerate() {
1886+
if i % 4 == 0 {
1887+
assert!(v.is_nan(), "lane {} should be NaN, got {}", i, v);
1888+
} else {
1889+
assert!((v - 1.0).abs() < 1e-4, "lane {} should be exp(0)=1, got {}", i, v);
1890+
}
1891+
}
1892+
}
1893+
18451894
#[test]
18461895
fn simd_exp_f32_handles_large_positive() {
18471896
// Without the clamp, x = 200 produced n = 288, ni + 127 = 415 which

0 commit comments

Comments
 (0)