Skip to content

Commit 1a73c37

Browse files
committed
fix(simd_half): preserve MXCSR across F16C cast batches (codex P2)
Per codex review on PR #183: `cast_f32_to_f16_batch_f16c` and `cast_f16_to_f32_batch_f16c` use F16C intrinsics that can raise FP exceptions (#O / #U / #P / #I / #D) on edge inputs — setting bits in the MXCSR status word. The scalar reference paths (`F16::to_f32`, `F16::from_f32_rounded`) are pure bit manipulation and never touch MXCSR, so the F16C fast path was introducing observable FP control-state side effects. Codex's proposed fix (`_mm256_cvtps_ph::<8>` with bit 3 set for `_MM_FROUND_NO_EXC`) does not apply here: the Rust stdarch intrinsic enforces `static_assert_uimm_bits!(IMM8, 3)` so IMM8 is constrained to `0..=7`, and the underlying VCVTPS2PH IMM8 encoding has no SAE bit — bit 3 selects MXCSR.RM (not NO_EXC, which is an AVX-512 convention). The only valid IMM8 values for F16C `_mm256_cvtps_ph` are 0..=3 (the four rounding modes). The actual fix: save MXCSR via STMXCSR before the SIMD region, restore via LDMXCSR after. Preserves every bit of the original control/status word (rounding mode, exception masks, flush-to- zero, and importantly the exception flag bits that the SIMD path may have set). Net effect: callers observe no MXCSR change vs. the scalar path. Implementation uses inline `asm!(stmxcsr/ldmxcsr)` rather than `_mm_getcsr` / `_mm_setcsr` because those wrappers are deprecated on stable Rust 1.95 (rustc deemed them unsound for cross-thread visibility reasons; the official guidance is exactly this — use inline asm). Two ops per batch call: one STMXCSR save at entry, one LDMXCSR restore at exit. Cost: ~5 cycles total, dwarfed by even a single 8-lane cvtps_ph chunk. New test `f16c_cast_preserves_mxcsr` exercises the fix: constructs input arrays containing 1e30 / -1e30 (overflow #O), 1e-30 (underflow / denormal #U / #D / #P), 1.0/3.0 (precision #P), NaN, Inf, ±0, 1.0 — values designed to trigger every relevant F16C exception. Snapshots MXCSR before, runs the cast, snapshots after, asserts byte-equal. Same check for the upcast direction with SNaN-encoded F16 inputs that trigger #I/#D in `_mm256_cvtph_ps`. Both pass on this host (F16C + avx2 silicon). Note: this fix does NOT prevent traps from firing on hosts where the caller has unmasked FP exceptions before calling us. Trap behaviour is the same as for any plain `a + b` of f32 that overflows — fires from the SIMD ops themselves, not under our control. Default MXCSR has all exception masks set (the process-startup state on Linux/macOS/Windows), so this is the common case and traps don't fire there. Verification: * 22 simd_half tests pass (was 21 before, +1 new MXCSR- preservation test). * Full lib sweep: 2087 tests pass. * cargo clippy -- -D warnings clean (no deprecation warning from _mm_getcsr / _mm_setcsr — we use inline asm instead). * cargo fmt --all --check clean. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
1 parent 5074048 commit 1a73c37

1 file changed

Lines changed: 121 additions & 4 deletions

File tree

src/simd_half.rs

Lines changed: 121 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,27 @@ pub fn cast_f32_to_f16_batch(src: &[f32], dst: &mut [F16]) {
426426
/// reference per IEEE 754 binary16 → binary32 spec** (lossless widening,
427427
/// no rounding possible).
428428
///
429+
/// # MXCSR preservation
430+
/// `_mm256_cvtph_ps` may raise `#I` (Invalid: SNaN input) or `#D`
431+
/// (Denormal) — setting bits in MXCSR that the scalar bit-fiddle
432+
/// reference [`F16::to_f32`] does not touch. To preserve the scalar
433+
/// path's contract of "no observable FP control/status side effects,"
434+
/// the MXCSR is saved before the SIMD region and restored after. Net
435+
/// effect: callers see no MXCSR change vs. the scalar path. (See
436+
/// codex review on PR #183.)
437+
///
429438
/// # Safety
430439
/// Caller must have feature-detected `f16c` + `avx` at runtime.
431440
#[cfg(target_arch = "x86_64")]
432441
#[target_feature(enable = "f16c,avx")]
433442
unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) {
443+
use core::arch::asm;
434444
use core::arch::x86_64::{__m128i, _mm256_cvtph_ps, _mm256_storeu_ps, _mm_loadu_si128};
445+
let mut saved_mxcsr: u32 = 0;
446+
// SAFETY: STMXCSR writes the 32-bit MXCSR control/status register
447+
// to the provided memory location; available on any SSE host
448+
// (baseline x86_64).
449+
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack));
435450
let n = src.len().min(dst.len());
436451
let chunks = n / 8;
437452
for c in 0..chunks {
@@ -444,6 +459,11 @@ unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) {
444459
for i in (chunks * 8)..n {
445460
dst[i] = F16(src[i]).to_f32();
446461
}
462+
// SAFETY: LDMXCSR reads the value we saved at the top — preserves
463+
// every bit of the original MXCSR (rounding mode, exception masks,
464+
// flush-to-zero etc.), clearing any exception flags the SIMD path
465+
// may have set.
466+
asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly));
447467
}
448468

449469
/// F16C-vectorized f32 → F16 batch with IEEE 754 RNE rounding.
@@ -452,17 +472,35 @@ unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) {
452472
/// one xmm store). The const `IMM8 = 0` selects
453473
/// `_MM_FROUND_TO_NEAREST_INT` — round-to-nearest-even, matches the
454474
/// scalar reference [`F16::from_f32_rounded`] bit-for-bit on every
455-
/// input. (Intel's `IMM8` for this intrinsic is 3 bits wide so the
456-
/// `_MM_FROUND_NO_EXC` flag is not selectable here; exceptions are
457-
/// raised but we ignore them — they don't affect the produced bit
458-
/// pattern.)
475+
/// input.
476+
///
477+
/// # IMM8 encoding limit
478+
/// `_mm256_cvtps_ph`'s `IMM8` is 3 bits wide (`static_assert_uimm_bits!
479+
/// (IMM8, 3)` in the Rust stdarch wrapper). Valid values are `0..=3`
480+
/// (the four rounding modes — RNE, down, up, truncate). Bits 2-3 of
481+
/// the underlying VCVTPS2PH IMM8 encoding are "reserved" and "select
482+
/// MXCSR.RM" per Intel SDM — NOT `_MM_FROUND_NO_EXC`, which is an
483+
/// AVX-512 convention (`_mm512_cvtps_ph` accepts `NO_EXC`, F16C does
484+
/// not). Exception suppression is handled at the MXCSR level (below).
485+
///
486+
/// # MXCSR preservation
487+
/// `_mm256_cvtps_ph` may raise `#O` (Overflow), `#U` (Underflow),
488+
/// `#P` (Precision), `#I` (Invalid for SNaN), `#D` (Denormal). The
489+
/// scalar reference [`F16::from_f32_rounded`] is pure bit
490+
/// manipulation and never touches MXCSR. We save/restore MXCSR around
491+
/// the SIMD region so callers see no observable control/status side
492+
/// effects regardless of input data. (See codex review on PR #183.)
459493
///
460494
/// # Safety
461495
/// Caller must have feature-detected `f16c` + `avx` at runtime.
462496
#[cfg(target_arch = "x86_64")]
463497
#[target_feature(enable = "f16c,avx")]
464498
unsafe fn cast_f32_to_f16_batch_f16c(src: &[f32], dst: &mut [u16]) {
499+
use core::arch::asm;
465500
use core::arch::x86_64::{__m128i, _mm256_cvtps_ph, _mm256_loadu_ps, _mm_storeu_si128};
501+
let mut saved_mxcsr: u32 = 0;
502+
// SAFETY: STMXCSR writes the 32-bit MXCSR; baseline SSE op.
503+
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack));
466504
let n = src.len().min(dst.len());
467505
let chunks = n / 8;
468506
for c in 0..chunks {
@@ -475,6 +513,8 @@ unsafe fn cast_f32_to_f16_batch_f16c(src: &[f32], dst: &mut [u16]) {
475513
for i in (chunks * 8)..n {
476514
dst[i] = F16::from_f32_rounded(src[i]).0;
477515
}
516+
// SAFETY: LDMXCSR restores the saved value bit-for-bit.
517+
asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly));
478518
}
479519

480520
// ============================================================================
@@ -853,4 +893,81 @@ mod tests {
853893
assert_eq!(dst[i], expected[i], "mul_f16_inplace mismatch at {}", i);
854894
}
855895
}
896+
897+
/// Codex PR #183 P2: F16C `_mm256_cvtps_ph` may raise FP exceptions
898+
/// (#O on overflow, #U on underflow, #P on precision loss, #I on
899+
/// SNaN, #D on denormal input) which set bits in MXCSR. The scalar
900+
/// path is pure bit manipulation and never touches MXCSR. The fix:
901+
/// `cast_f32_to_f16_batch_f16c` saves MXCSR via STMXCSR before the
902+
/// SIMD region and restores it via LDMXCSR after. This test feeds
903+
/// inputs that should trigger every exception bit and asserts
904+
/// MXCSR is byte-identical before vs. after the call.
905+
#[cfg(target_arch = "x86_64")]
906+
#[test]
907+
fn f16c_cast_preserves_mxcsr() {
908+
if !std::is_x86_feature_detected!("f16c") {
909+
eprintln!("f16c not detected; skipping");
910+
return;
911+
}
912+
use core::arch::asm;
913+
914+
// Inputs designed to trigger #O / #U / #P / #I / #D in F16C
915+
// downcast:
916+
// - 1e30, -1e30 : overflow (out of F16 range ±65504) → #O
917+
// - 1e-30 : underflow / denormal → #U, #D, #P
918+
// - 1.0/3.0 : precision loss → #P
919+
// - f32::NAN : invalid (if it's an sNaN representation) → #I
920+
let inputs: Vec<f32> = vec![
921+
1e30,
922+
-1e30,
923+
1e-30,
924+
1.0 / 3.0,
925+
f32::NAN,
926+
f32::INFINITY,
927+
0.0,
928+
1.0,
929+
// Pad to 8 lanes so the SIMD chunk loop fires once with no tail.
930+
];
931+
assert_eq!(inputs.len(), 8);
932+
let mut out = vec![F16::ZERO; 8];
933+
934+
// Snapshot MXCSR before.
935+
let mut mxcsr_before: u32 = 0;
936+
unsafe {
937+
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_before, options(nostack));
938+
}
939+
940+
cast_f32_to_f16_batch(&inputs, &mut out);
941+
942+
// Snapshot MXCSR after.
943+
let mut mxcsr_after: u32 = 0;
944+
unsafe {
945+
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_after, options(nostack));
946+
}
947+
948+
assert_eq!(
949+
mxcsr_before, mxcsr_after,
950+
"cast_f32_to_f16_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)",
951+
mxcsr_before, mxcsr_after
952+
);
953+
954+
// Same check for the upcast direction (`_mm256_cvtph_ps` can raise
955+
// #I/#D on SNaN/denormal F16 input).
956+
let f16_inputs: Vec<F16> = (0..8).map(|i| F16(0x7C01 + i as u16)).collect(); // SNaN-ish
957+
let mut f32_out = vec![0.0f32; 8];
958+
959+
unsafe {
960+
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_before, options(nostack));
961+
}
962+
cast_f16_to_f32_batch(&f16_inputs, &mut f32_out);
963+
unsafe {
964+
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_after, options(nostack));
965+
}
966+
967+
assert_eq!(
968+
mxcsr_before, mxcsr_after,
969+
"cast_f16_to_f32_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)",
970+
mxcsr_before, mxcsr_after
971+
);
972+
}
856973
}

0 commit comments

Comments
 (0)