Skip to content

Commit 8878ff2

Browse files
committed
Add IEEE 754 f16 (half-precision) type to simd_avx512.rs — additive only
No existing code modified. New functions appended at end of file:

Scalar (exact, all platforms):
- f16_to_f32_ieee754: lossless widening (subnormals, Inf, NaN preserved)
- f32_to_f16_ieee754_rne: narrowing with RNE (Round-to-Nearest-Even)

Batch (runtime-detected, tiered):
- f16_to_f32_batch_ieee754: AVX-512F (16-wide) → F16C (8-wide) → scalar
- f32_to_f16_batch_ieee754_rne: AVX-512F (16-wide) → F16C (8-wide) → scalar

Uses hardware F16C instructions (stable target_feature since Rust 1.68):
- VCVTPH2PS: u16 → f32 (exact)
- VCVTPS2PH: f32 → u16 (imm8=0x00 for RNE)

IEEE 754 binary16: 1 sign + 5 exp (bias 15) + 10 mantissa
Range: ±65504, precision: 3.31 decimal digits

6 new tests, all passing. Existing BF16 tests unaffected.

https://claude.ai/code/session_017ZN5PNEf8boFBgorUZVrFU
1 parent 5dc9db3 commit 8878ff2

1 file changed

Lines changed: 356 additions & 0 deletions

File tree

src/simd_avx512.rs

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,3 +2356,359 @@ mod bf16_tests {
23562356
}
23572357
}
23582358
}
2359+
2360+
// ════════════════════════════════════════════════════════════════════════════
2361+
// F16 (IEEE 754 Half-Precision) — via F16C instructions (stable since Rust 1.68)
2362+
//
2363+
// IEEE 754 binary16: 1 sign + 5 exponent + 10 mantissa
2364+
// Range: ±65504, precision: ~3.3 decimal digits
2365+
// Subnormals: ±5.96×10⁻⁸ minimum positive
2366+
//
2367+
// Hardware instructions (F16C, stable target_feature):
2368+
// _mm256_cvtph_ps: 8× f16(u16) → 8× f32 (VCVTPH2PS ymm, xmm)
2369+
// _mm512_cvtph_ps: 16× f16(u16) → 16× f32 (VCVTPH2PS zmm, ymm) [AVX-512F]
2370+
// _mm256_cvtps_ph: 8× f32 → 8× f16(u16) (VCVTPS2PH xmm, ymm, imm8)
2371+
// _mm512_cvtps_ph: 16× f32 → 16× f16(u16) (VCVTPS2PH ymm, zmm, imm8) [AVX-512F]
2372+
//
2373+
// imm8 for rounding:
2374+
// 0x00 = Round to nearest even (IEEE default)
2375+
// 0x01 = Round toward negative infinity
2376+
// 0x02 = Round toward positive infinity
2377+
// 0x03 = Round toward zero (truncate)
2378+
// 0x04 = Use MXCSR rounding mode
2379+
//
2380+
// NOTE: F16C is available on Haswell+ (2013), essentially all modern x86_64.
2381+
// The zmm-width (16-lane) forms of VCVTPH2PS/VCVTPS2PH are part of AVX-512F, not F16C.
2382+
// ════════════════════════════════════════════════════════════════════════════
2383+
2384+
/// Widens an IEEE 754 binary16 bit pattern to the `f32` with the same value.
///
/// Layouts:
/// - binary16: 1 sign | 5 exponent (bias 15)  | 10 fraction
/// - binary32: 1 sign | 8 exponent (bias 127) | 23 fraction
///
/// Every finite binary16 value is exactly representable in binary32, so no
/// rounding takes place. Zeros keep their sign, infinities map to infinities,
/// and a NaN keeps its sign with its 10 fraction bits placed in the top 10
/// fraction bits of the result.
///
/// NOTE(review): a signalling NaN is passed through with its quiet bit still
/// clear. Hardware conversion instructions may quiet it instead — confirm
/// whether scalar/SIMD agreement on signalling NaNs matters to callers.
pub fn f16_to_f32_ieee754(bits: u16) -> f32 {
    let half = u32::from(bits);
    let sign_bit = (half & 0x8000) << 16;
    let exp_field = (half >> 10) & 0x1F;
    let frac = half & 0x03FF;

    // Everything below the sign bit of the resulting binary32 pattern.
    let magnitude = match (exp_field, frac) {
        // ±0.0
        (0, 0) => 0,
        // Subnormal: value = frac × 2^(−24). Renormalise by moving the highest
        // set fraction bit up to bit 10 (the implicit-one position).
        (0, _) => {
            // `frac` is in 1..=0x3FF, so leading_zeros() is in 22..=31 and
            // `lift` (the number of left shifts needed) is in 1..=10.
            let lift = frac.leading_zeros() - 21;
            let normalized = (frac << lift) & 0x03FF; // drop the now-implicit one
            // True exponent is (1 − 15) − lift; rebias by +127 → 113 − lift.
            let exp32 = 113 - lift;
            (exp32 << 23) | (normalized << 13)
        }
        // Inf (frac == 0) or NaN (frac != 0): all-ones exponent, widened payload.
        (0x1F, _) => 0x7F80_0000 | (frac << 13),
        // Normal: rebias 15 → 127 is a plain +112 on the exponent field.
        _ => ((exp_field + 112) << 23) | (frac << 13),
    };

    f32::from_bits(sign_bit | magnitude)
}
2424+
2425+
/// Narrows an `f32` to an IEEE 754 binary16 bit pattern using
/// round-to-nearest, ties-to-even (RNE).
///
/// Intended to agree bit-for-bit with hardware `VCVTPS2PH` at `imm8 = 0x00`,
/// so that SIMD lanes and the scalar tail of a batch produce identical output.
///
/// Result by input class (sign is always carried over):
/// - ±0, every f32 subnormal, and any magnitude ≤ 2^−25 → ±0
///   (2^−25 itself is the exact tie between 0 and 2^−24 and goes to the even
///   neighbour, 0).
/// - magnitude in (2^−25, 2^−14) → f16 subnormal; rounding may carry up into
///   the smallest normal, `0x0400`.
/// - magnitude in [2^−14, 65520) → f16 normal.
/// - magnitude ≥ 65520 (rounds above the largest finite value 65504) or
///   ±Inf → ±Inf.
/// - NaN → quiet NaN carrying the top 10 payload bits of the input. A
///   signalling NaN is quieted (bit 9 forced on), as an IEEE 754 conversion
///   does; the NaN sign is kept.
///
/// Precision: 10 fraction bits, about 3.31 decimal digits.
pub fn f32_to_f16_ieee754_rne(v: f32) -> u16 {
    /// Shifts `value` right by `shift` bits (1..=31), rounding the discarded
    /// bits to nearest with ties going to the even result.
    fn shift_right_rne(value: u32, shift: u32) -> u32 {
        let kept = value >> shift;
        let dropped = value & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        if dropped > half || (dropped == half && kept & 1 == 1) {
            kept + 1
        } else {
            kept
        }
    }

    const H_INF: u16 = 0x7C00; // exponent all ones, fraction zero
    const H_QUIET: u16 = 0x0200; // quiet-NaN bit (f32 bit 22 → f16 bit 9)

    let bits = v.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x007F_FFFF;

    if exp == 0xFF {
        return if mant == 0 {
            sign | H_INF
        } else {
            // Truncate the payload to its top 10 bits and force the quiet bit.
            // This both keeps the result a NaN when the surviving payload is
            // zero and quiets signalling NaNs.
            sign | H_INF | H_QUIET | (mant >> 13) as u16
        };
    }

    // For exp == 0 (zero or f32 subnormal) this is −127, far below the f16
    // range, so those inputs fall into the underflow-to-zero case below.
    let unbiased = exp - 127;

    if unbiased > 15 {
        // Magnitude ≥ 2^16: beyond the largest finite f16.
        return sign | H_INF;
    }
    if unbiased < -25 {
        // Magnitude < 2^−25, i.e. strictly below half of the smallest f16
        // subnormal (2^−24): rounds to zero. Magnitudes in [2^−25, 2^−24) must
        // NOT be flushed here — all but the exact tie round up to 0x0001.
        return sign;
    }

    if unbiased < -14 {
        // f16 subnormal range (plus the 2^−25 binade that rounds to 0 or 1).
        // Restore the implicit leading one, then drop 13 bits for the
        // 23 → 10 fraction narrowing plus one more bit for every step the
        // exponent sits below −14.
        let total_shift = (13 + (-14 - unbiased)) as u32; // 14..=24
        let rounded = shift_right_rne(mant | 0x0080_0000, total_shift);
        // `rounded` is at most 0x400; 0x400 is exactly the encoding of the
        // smallest normal (exponent field 1, fraction 0).
        return sign | rounded as u16;
    }

    // f16 normal range: exponent field 1..=30 before rounding.
    let h_exp = (unbiased + 15) as u32;
    let rounded = shift_right_rne(mant, 13); // 0..=0x400
    // Adding (not OR-ing) lets a fraction carry bump the exponent; a carry out
    // of exponent 30 yields 0x7C00, which is exactly ±Inf.
    sign | ((h_exp << 10) + rounded) as u16
}
2524+
2525+
/// Batch f16 → f32 via AVX-512 VCVTPH2PS (16 lanes) with F16C fallback (8 lanes).
2526+
///
2527+
/// Detection: avx512f → 16-wide | f16c → 8-wide | scalar fallback
2528+
/// Conversion is exact (lossless widening).
2529+
pub fn f16_to_f32_batch_ieee754(input: &[u16], output: &mut [f32]) {
2530+
let n = input.len().min(output.len());
2531+
2532+
#[cfg(target_arch = "x86_64")]
2533+
{
2534+
// Tier 1: AVX-512F (16 lanes per instruction)
2535+
if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
2536+
let chunks16 = n / 16;
2537+
for c in 0..chunks16 {
2538+
unsafe {
2539+
// SAFETY: avx512f + f16c verified above.
2540+
let src = _mm256_loadu_si256(input[c*16..].as_ptr() as *const __m256i);
2541+
let dst = _mm512_cvtph_ps(src);
2542+
_mm512_storeu_ps(output[c*16..].as_mut_ptr(), dst);
2543+
}
2544+
}
2545+
// Scalar tail
2546+
for i in (chunks16*16)..n {
2547+
output[i] = f16_to_f32_ieee754(input[i]);
2548+
}
2549+
return;
2550+
}
2551+
// Tier 2: F16C (8 lanes per instruction, Haswell+)
2552+
if is_x86_feature_detected!("f16c") {
2553+
let chunks8 = n / 8;
2554+
for c in 0..chunks8 {
2555+
unsafe {
2556+
// SAFETY: f16c verified above.
2557+
let src = _mm_loadu_si128(input[c*8..].as_ptr() as *const __m128i);
2558+
let dst = _mm256_cvtph_ps(src);
2559+
_mm256_storeu_ps(output[c*8..].as_mut_ptr(), dst);
2560+
}
2561+
}
2562+
for i in (chunks8*8)..n {
2563+
output[i] = f16_to_f32_ieee754(input[i]);
2564+
}
2565+
return;
2566+
}
2567+
}
2568+
2569+
// Scalar fallback (exact)
2570+
for i in 0..n {
2571+
output[i] = f16_to_f32_ieee754(input[i]);
2572+
}
2573+
}
2574+
2575+
/// Batch f32 → f16 via AVX-512 VCVTPS2PH (16 lanes) with RNE rounding.
2576+
///
2577+
/// imm8 = 0x00: Round-to-Nearest-Even (IEEE 754 default).
2578+
/// Matches hardware behavior bit-exact.
2579+
pub fn f32_to_f16_batch_ieee754_rne(input: &[f32], output: &mut [u16]) {
2580+
let n = input.len().min(output.len());
2581+
2582+
#[cfg(target_arch = "x86_64")]
2583+
{
2584+
// Tier 1: AVX-512F (16 lanes, RNE via imm8=0)
2585+
if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
2586+
let chunks16 = n / 16;
2587+
for c in 0..chunks16 {
2588+
unsafe {
2589+
// SAFETY: avx512f + f16c verified above.
2590+
let src = _mm512_loadu_ps(input[c*16..].as_ptr());
2591+
// imm8=0x00: _MM_FROUND_TO_NEAREST_INT (RNE)
2592+
let dst: __m256i = _mm512_cvtps_ph::<0x00>(src);
2593+
_mm256_storeu_si256(output[c*16..].as_mut_ptr() as *mut __m256i, dst);
2594+
}
2595+
}
2596+
for i in (chunks16*16)..n {
2597+
output[i] = f32_to_f16_ieee754_rne(input[i]);
2598+
}
2599+
return;
2600+
}
2601+
// Tier 2: F16C (8 lanes, RNE)
2602+
if is_x86_feature_detected!("f16c") {
2603+
let chunks8 = n / 8;
2604+
for c in 0..chunks8 {
2605+
unsafe {
2606+
// SAFETY: f16c verified above.
2607+
let src = _mm256_loadu_ps(input[c*8..].as_ptr());
2608+
let dst: __m128i = _mm256_cvtps_ph::<0x00>(src);
2609+
_mm_storeu_si128(output[c*8..].as_mut_ptr() as *mut __m128i, dst);
2610+
}
2611+
}
2612+
for i in (chunks8*8)..n {
2613+
output[i] = f32_to_f16_ieee754_rne(input[i]);
2614+
}
2615+
return;
2616+
}
2617+
}
2618+
2619+
// Scalar RNE fallback
2620+
for i in 0..n {
2621+
output[i] = f32_to_f16_ieee754_rne(input[i]);
2622+
}
2623+
}
2624+
2625+
#[cfg(test)]
mod f16_tests {
    use super::*;

    /// Known binary16 encodings must widen to the exact binary32 bit pattern.
    ///
    /// Bit patterns are compared rather than float values because
    /// `0.0 == -0.0`, so a value comparison cannot catch a lost sign.
    #[test]
    fn f16_ieee754_exact_values() {
        let cases: [(u16, f32); 12] = [
            (0x0000, 0.0),                         // +0
            (0x8000, -0.0),                        // −0
            (0x3C00, 1.0),                         // 1.0
            (0xBC00, -1.0),                        // −1.0
            (0x4000, 2.0),                         // 2.0
            (0x3800, 0.5),                         // 0.5
            (0x7BFF, 65504.0),                     // max normal
            (0x0400, f32::from_bits(0x3880_0000)), // smallest normal, 2^−14
            (0x0001, f32::from_bits(0x3380_0000)), // smallest subnormal, 2^−24
            (0x03FF, 1023.0 / 16_777_216.0),       // largest subnormal, 1023 × 2^−24
            (0x7C00, f32::INFINITY),               // +Inf
            (0xFC00, f32::NEG_INFINITY),           // −Inf
        ];
        for &(h, expected) in &cases {
            let got = f16_to_f32_ieee754(h);
            assert_eq!(got.to_bits(), expected.to_bits(),
                "0x{:04X} widened to {} (0x{:08X}), expected {} (0x{:08X})",
                h, got, got.to_bits(), expected, expected.to_bits());
        }

        // Any pattern with exponent 31 and a non-zero fraction is a NaN.
        for &h in &[0x7C01u16, 0x7E00, 0x7FFF, 0xFC01, 0xFFFF] {
            assert!(f16_to_f32_ieee754(h).is_nan(), "0x{:04X} should be NaN", h);
        }
    }

    /// Every normal f16 (both signs, all 1024 fractions) must survive
    /// f16 → f32 → f16 unchanged.
    #[test]
    fn f16_rne_roundtrip_normals() {
        for sign in [0u16, 0x8000] {
            for exp in 1u16..31 {
                for mant in 0u16..1024 {
                    let h = sign | (exp << 10) | mant;
                    let f = f16_to_f32_ieee754(h);
                    let back = f32_to_f16_ieee754_rne(f);
                    assert_eq!(h, back,
                        "roundtrip failed: 0x{:04X} → {} → 0x{:04X}", h, f, back);
                }
            }
        }
    }

    /// Zeros, every subnormal, and the infinities must also roundtrip, with
    /// the sign bit intact. NaNs are deliberately left out of this check.
    #[test]
    fn f16_rne_roundtrip_zero_subnormal_inf() {
        for sign in [0u16, 0x8000] {
            for mant in 0u16..1024 {
                let h = sign | mant; // exponent field 0: zero or subnormal
                let f = f16_to_f32_ieee754(h);
                let back = f32_to_f16_ieee754_rne(f);
                assert_eq!(h, back,
                    "roundtrip failed: 0x{:04X} → {} → 0x{:04X}", h, f, back);
            }
            let inf = sign | 0x7C00;
            assert_eq!(f32_to_f16_ieee754_rne(f16_to_f32_ieee754(inf)), inf);
        }
    }

    /// f32 values that are exactly representable in f16 must come back
    /// bit-identical after f32 → f16 → f32.
    #[test]
    fn f16_exact_representable_values() {
        let exact_values: &[f32] = &[
            0.0, -0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 0.25, 0.125,
            65504.0, -65504.0, // max f16
            0.000061035156,    // smallest normal f16 (2^-14)
        ];
        for &v in exact_values {
            let h = f32_to_f16_ieee754_rne(v);
            let back = f16_to_f32_ieee754(h);
            assert_eq!(v.to_bits(), back.to_bits(),
                "exact value roundtrip failed: {} → 0x{:04X} → {}", v, h, back);
        }
    }

    /// Magnitudes that round above 65504 must become the correctly signed
    /// infinity.
    #[test]
    fn f16_overflow_to_inf() {
        for &big in &[65520.0f32, 100000.0, f32::MAX, f32::INFINITY] {
            assert_eq!(f32_to_f16_ieee754_rne(big), 0x7C00, "{} should be +Inf", big);
            assert_eq!(f32_to_f16_ieee754_rne(-big), 0xFC00, "{} should be −Inf", -big);
        }
    }

    /// Batch widening must agree bit-for-bit with the scalar routine.
    ///
    /// 203 elements is a multiple of neither 8 nor 16, so the scalar tail is
    /// exercised whichever SIMD tier is selected.
    #[test]
    fn f16_batch_matches_scalar() {
        const LEN: usize = 203;
        let input: Vec<u16> = (0..LEN)
            .map(|i| f32_to_f16_ieee754_rne((i as f32 - 100.0) * 0.5))
            .collect();
        let mut batch_out = vec![0.0f32; LEN];
        f16_to_f32_batch_ieee754(&input, &mut batch_out);

        for (i, (&h, &got)) in input.iter().zip(batch_out.iter()).enumerate() {
            let scalar = f16_to_f32_ieee754(h);
            assert_eq!(got.to_bits(), scalar.to_bits(),
                "batch/scalar mismatch at {}: batch=0x{:08X} scalar=0x{:08X}",
                i, got.to_bits(), scalar.to_bits());
        }
    }

    /// Batch narrowing must agree bit-for-bit with the scalar RNE routine.
    ///
    /// 203 elements is a multiple of neither 8 nor 16, so the scalar tail is
    /// exercised whichever SIMD tier is selected.
    #[test]
    fn f32_to_f16_batch_rne_matches_scalar() {
        const LEN: usize = 203;
        let input: Vec<f32> = (0..LEN).map(|i| (i as f32 - 100.0) * 0.37).collect();
        let mut batch_out = vec![0u16; LEN];
        f32_to_f16_batch_ieee754_rne(&input, &mut batch_out);

        for (i, (&v, &got)) in input.iter().zip(batch_out.iter()).enumerate() {
            let scalar = f32_to_f16_ieee754_rne(v);
            assert_eq!(got, scalar,
                "f32→f16 batch/scalar mismatch at {}: input={} batch=0x{:04X} scalar=0x{:04X}",
                i, v, got, scalar);
        }
    }
}

0 commit comments

Comments
 (0)