@@ -2356,3 +2356,359 @@ mod bf16_tests {
23562356 }
23572357 }
23582358}
2359+
2360+ // ════════════════════════════════════════════════════════════════════════════
2361+ // F16 (IEEE 754 Half-Precision) — via F16C instructions (stable since Rust 1.68)
2362+ //
2363+ // IEEE 754 binary16: 1 sign + 5 exponent + 10 mantissa
2364+ // Range: ±65504, precision: ~3.3 decimal digits
2365+ // Subnormals: ±5.96×10⁻⁸ minimum positive
2366+ //
2367+ // Hardware instructions (F16C, stable target_feature):
2368+ // _mm256_cvtph_ps: 8× f16(u16) → 8× f32 (VCVTPH2PS ymm, xmm)
2369+ // _mm512_cvtph_ps: 16× f16(u16) → 16× f32 (VCVTPH2PS zmm, ymm) [AVX-512F]
2370+ // _mm256_cvtps_ph: 8× f32 → 8× f16(u16) (VCVTPS2PH xmm, ymm, imm8)
2371+ // _mm512_cvtps_ph: 16× f32 → 16× f16(u16) (VCVTPS2PH ymm, zmm, imm8) [AVX-512F]
2372+ //
2373+ // imm8 for rounding:
2374+ // 0x00 = Round to nearest even (IEEE default)
2375+ // 0x01 = Round toward negative infinity
2376+ // 0x02 = Round toward positive infinity
2377+ // 0x03 = Round toward zero (truncate)
2378+ // 0x04 = Use MXCSR rounding mode
2379+ //
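// NOTE: F16C first shipped with Intel Ivy Bridge (2012) and AMD Piledriver/Jaguar,
// so essentially all modern x86_64 CPUs support it.
// The zmm-width conversions (_mm512_cvtph_ps / _mm512_cvtps_ph) are part of
// AVX-512F itself, not of F16C.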
2382+ // ════════════════════════════════════════════════════════════════════════════
2383+
/// IEEE 754 binary16 → binary32 scalar conversion.
///
/// binary16: 1 sign | 5 exponent (bias 15)  | 10 mantissa
/// binary32: 1 sign | 8 exponent (bias 127) | 23 mantissa
///
/// `bits` is the raw binary16 encoding. Every finite value, both zeros and
/// both infinities convert exactly, because binary32 is a strict superset of
/// binary16. NaNs keep their sign and their 10 payload bits (moved to the top
/// of the f32 mantissa) and are always returned *quiet*: a signaling NaN gets
/// its quiet bit set, which is what IEEE 754 requires of a format conversion
/// and what the hardware `VCVTPH2PS` instruction produces.
pub fn f16_to_f32_ieee754(bits: u16) -> f32 {
    /// Most significant mantissa bit of binary32: set ⇒ quiet NaN.
    const F32_QUIET_BIT: u32 = 0x0040_0000;
    /// Exponent field of binary32 with all bits set (Inf / NaN).
    const F32_EXP_ALL_ONES: u32 = 0x7F80_0000;

    let sign = u32::from(bits & 0x8000) << 16;
    let exp = u32::from((bits >> 10) & 0x1F);
    let mant = u32::from(bits & 0x03FF);

    let magnitude = match (exp, mant) {
        // ±0.0
        (0, 0) => 0,
        // Subnormal: value = mant × 2^-24 with no implicit leading 1.
        (0, _) => {
            // Number of left shifts that move the highest set bit of `mant`
            // (somewhere in bits 0..=9) up to bit 10, the implicit-1 position.
            let shift = mant.leading_zeros() - 21;
            // Drop the now-implicit leading 1, keep the 10 fraction bits.
            let frac = (mant << shift) & 0x03FF;
            // True exponent is -14 - shift; rebias for f32: 127 - 14 - shift.
            let biased = 113 - shift;
            (biased << 23) | (frac << 13)
        }
        // ±Inf
        (0x1F, 0) => F32_EXP_ALL_ONES,
        // NaN: widen the payload and force the quiet bit.
        (0x1F, _) => F32_EXP_ALL_ONES | F32_QUIET_BIT | (mant << 13),
        // Normal: rebias exponent (bias 15 → bias 127) by adding 112.
        _ => ((exp + 112) << 23) | (mant << 13),
    };

    f32::from_bits(sign | magnitude)
}
2424+
/// IEEE 754 binary32 → binary16 scalar conversion with Round-to-Nearest-Even.
///
/// Intended to produce the same bits as hardware `VCVTPS2PH` with imm8 = 0x00.
///
/// * Normal results are rounded from 23 to 10 mantissa bits (ties to even);
///   a mantissa carry bumps the exponent and may overflow to ±Inf.
/// * Magnitudes of 65520 and above become ±Inf.
/// * Magnitudes below 2^-14 become f16 subnormals. Anything strictly above
///   2^-25 (half the smallest subnormal) rounds to at least ±0x0001; exactly
///   2^-25 is a tie and rounds to the even neighbour, ±0.
/// * NaNs keep their sign and the top 9 payload bits and are always returned
///   quiet (bit 9 set), so a signaling NaN never leaks through.
///
/// Returns the raw binary16 encoding.
pub fn f32_to_f16_ieee754_rne(v: f32) -> u16 {
    const H_INF: u16 = 0x7C00;
    const H_QUIET_BIT: u16 = 0x0200;

    let bits = v.to_bits();
    // Masked to bit 15 before the cast, so nothing is truncated.
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x007F_FFFF;

    if exp == 0xFF {
        if mant == 0 {
            return sign | H_INF;
        }
        // NaN: keep the payload bits that fit below the quiet bit.
        let payload = ((mant >> 13) & 0x01FF) as u16;
        return sign | H_INF | H_QUIET_BIT | payload;
    }

    // True exponent. Zeros and f32 subnormals (exp == 0) land at -127 and
    // fall into the underflow case below.
    let unbiased = exp - 127;

    if unbiased > 15 {
        // Beyond the largest finite binary16 exponent → ±Inf.
        return sign | H_INF;
    }
    if unbiased < -25 {
        // Strictly below half of the smallest subnormal → ±0.
        return sign;
    }

    // Pick the significand to round, how many low bits to discard, and the
    // exponent field the result starts from.
    let (significand, shift, base_exp) = if unbiased < -14 {
        // Subnormal result: make the implicit 1 explicit and shift it down by
        // the usual 13 bits plus the exponent deficit (-14 - unbiased).
        (mant | 0x0080_0000, (-1 - unbiased) as u32, 0u32)
    } else {
        // Normal result: rebias the exponent (+15) and drop 13 mantissa bits.
        (mant, 13u32, (unbiased + 15) as u32)
    };

    let kept = significand >> shift;
    let discarded = significand & ((1u32 << shift) - 1);
    let half = 1u32 << (shift - 1);
    let round_up = discarded > half || (discarded == half && kept & 1 == 1);

    // Adding instead of OR-ing lets a mantissa carry (kept + 1 == 0x400) flow
    // into the exponent field: subnormal → smallest normal, normal → next
    // binade, and exponent 30 → 0x7C00 (Inf) all fall out for free.
    let magnitude = (base_exp << 10) + kept + u32::from(round_up);

    // magnitude <= 0x7C00 by construction, so the cast is lossless.
    sign | magnitude as u16
}
2524+
2525+ /// Batch f16 → f32 via AVX-512 VCVTPH2PS (16 lanes) with F16C fallback (8 lanes).
2526+ ///
2527+ /// Detection: avx512f → 16-wide | f16c → 8-wide | scalar fallback
2528+ /// Conversion is exact (lossless widening).
2529+ pub fn f16_to_f32_batch_ieee754 ( input : & [ u16 ] , output : & mut [ f32 ] ) {
2530+ let n = input. len ( ) . min ( output. len ( ) ) ;
2531+
2532+ #[ cfg( target_arch = "x86_64" ) ]
2533+ {
2534+ // Tier 1: AVX-512F (16 lanes per instruction)
2535+ if is_x86_feature_detected ! ( "avx512f" ) && is_x86_feature_detected ! ( "f16c" ) {
2536+ let chunks16 = n / 16 ;
2537+ for c in 0 ..chunks16 {
2538+ unsafe {
2539+ // SAFETY: avx512f + f16c verified above.
2540+ let src = _mm256_loadu_si256 ( input[ c* 16 ..] . as_ptr ( ) as * const __m256i ) ;
2541+ let dst = _mm512_cvtph_ps ( src) ;
2542+ _mm512_storeu_ps ( output[ c* 16 ..] . as_mut_ptr ( ) , dst) ;
2543+ }
2544+ }
2545+ // Scalar tail
2546+ for i in ( chunks16* 16 ) ..n {
2547+ output[ i] = f16_to_f32_ieee754 ( input[ i] ) ;
2548+ }
2549+ return ;
2550+ }
2551+ // Tier 2: F16C (8 lanes per instruction, Haswell+)
2552+ if is_x86_feature_detected ! ( "f16c" ) {
2553+ let chunks8 = n / 8 ;
2554+ for c in 0 ..chunks8 {
2555+ unsafe {
2556+ // SAFETY: f16c verified above.
2557+ let src = _mm_loadu_si128 ( input[ c* 8 ..] . as_ptr ( ) as * const __m128i ) ;
2558+ let dst = _mm256_cvtph_ps ( src) ;
2559+ _mm256_storeu_ps ( output[ c* 8 ..] . as_mut_ptr ( ) , dst) ;
2560+ }
2561+ }
2562+ for i in ( chunks8* 8 ) ..n {
2563+ output[ i] = f16_to_f32_ieee754 ( input[ i] ) ;
2564+ }
2565+ return ;
2566+ }
2567+ }
2568+
2569+ // Scalar fallback (exact)
2570+ for i in 0 ..n {
2571+ output[ i] = f16_to_f32_ieee754 ( input[ i] ) ;
2572+ }
2573+ }
2574+
2575+ /// Batch f32 → f16 via AVX-512 VCVTPS2PH (16 lanes) with RNE rounding.
2576+ ///
2577+ /// imm8 = 0x00: Round-to-Nearest-Even (IEEE 754 default).
2578+ /// Matches hardware behavior bit-exact.
2579+ pub fn f32_to_f16_batch_ieee754_rne ( input : & [ f32 ] , output : & mut [ u16 ] ) {
2580+ let n = input. len ( ) . min ( output. len ( ) ) ;
2581+
2582+ #[ cfg( target_arch = "x86_64" ) ]
2583+ {
2584+ // Tier 1: AVX-512F (16 lanes, RNE via imm8=0)
2585+ if is_x86_feature_detected ! ( "avx512f" ) && is_x86_feature_detected ! ( "f16c" ) {
2586+ let chunks16 = n / 16 ;
2587+ for c in 0 ..chunks16 {
2588+ unsafe {
2589+ // SAFETY: avx512f + f16c verified above.
2590+ let src = _mm512_loadu_ps ( input[ c* 16 ..] . as_ptr ( ) ) ;
2591+ // imm8=0x00: _MM_FROUND_TO_NEAREST_INT (RNE)
2592+ let dst: __m256i = _mm512_cvtps_ph :: < 0x00 > ( src) ;
2593+ _mm256_storeu_si256 ( output[ c* 16 ..] . as_mut_ptr ( ) as * mut __m256i , dst) ;
2594+ }
2595+ }
2596+ for i in ( chunks16* 16 ) ..n {
2597+ output[ i] = f32_to_f16_ieee754_rne ( input[ i] ) ;
2598+ }
2599+ return ;
2600+ }
2601+ // Tier 2: F16C (8 lanes, RNE)
2602+ if is_x86_feature_detected ! ( "f16c" ) {
2603+ let chunks8 = n / 8 ;
2604+ for c in 0 ..chunks8 {
2605+ unsafe {
2606+ // SAFETY: f16c verified above.
2607+ let src = _mm256_loadu_ps ( input[ c* 8 ..] . as_ptr ( ) ) ;
2608+ let dst: __m128i = _mm256_cvtps_ph :: < 0x00 > ( src) ;
2609+ _mm_storeu_si128 ( output[ c* 8 ..] . as_mut_ptr ( ) as * mut __m128i , dst) ;
2610+ }
2611+ }
2612+ for i in ( chunks8* 8 ) ..n {
2613+ output[ i] = f32_to_f16_ieee754_rne ( input[ i] ) ;
2614+ }
2615+ return ;
2616+ }
2617+ }
2618+
2619+ // Scalar RNE fallback
2620+ for i in 0 ..n {
2621+ output[ i] = f32_to_f16_ieee754_rne ( input[ i] ) ;
2622+ }
2623+ }
2624+
#[cfg(test)]
mod f16_tests {
    use super::*;

    /// Spot-checks well-known binary16 encodings against their f32 values.
    #[test]
    fn f16_ieee754_exact_values() {
        assert_eq!(f16_to_f32_ieee754(0x0000).to_bits(), 0.0f32.to_bits()); // +0
        assert_eq!(f16_to_f32_ieee754(0x8000).to_bits(), (-0.0f32).to_bits()); // −0
        assert_eq!(f16_to_f32_ieee754(0x3C00), 1.0); // 1.0
        assert_eq!(f16_to_f32_ieee754(0xBC00), -1.0); // −1.0
        assert_eq!(f16_to_f32_ieee754(0x4000), 2.0); // 2.0
        assert_eq!(f16_to_f32_ieee754(0x3800), 0.5); // 0.5
        assert_eq!(f16_to_f32_ieee754(0x7BFF), 65504.0); // max normal
        assert_eq!(f16_to_f32_ieee754(0x7C00), f32::INFINITY); // +Inf
        assert_eq!(f16_to_f32_ieee754(0xFC00), f32::NEG_INFINITY); // −Inf
        assert!(f16_to_f32_ieee754(0x7C01).is_nan()); // NaN
        // Smallest positive subnormal is exactly 2^(−24).
        assert_eq!(f16_to_f32_ieee754(0x0001).to_bits(), 0x3380_0000);
    }

    /// Every normal f16 (both signs, all exponents, all mantissas) must
    /// survive f16 → f32 → f16 unchanged.
    #[test]
    fn f16_rne_roundtrip_normals() {
        for sign in [0x0000u16, 0x8000] {
            for magnitude in 0x0400u16..0x7C00 {
                let h = sign | magnitude;
                let f = f16_to_f32_ieee754(h);
                let back = f32_to_f16_ieee754_rne(f);
                assert_eq!(h, back,
                    "roundtrip failed: 0x{:04X} → {} → 0x{:04X}", h, f, back);
            }
        }
    }

    /// Zeros, all subnormals and infinities must roundtrip bit-exactly too.
    #[test]
    fn f16_rne_roundtrip_subnormals_zeros_infs() {
        for sign in [0x0000u16, 0x8000] {
            for magnitude in (0x0000u16..0x0400).chain(std::iter::once(0x7C00)) {
                let h = sign | magnitude;
                let f = f16_to_f32_ieee754(h);
                let back = f32_to_f16_ieee754_rne(f);
                assert_eq!(h, back,
                    "roundtrip failed: 0x{:04X} → {} → 0x{:04X}", h, f, back);
            }
        }
    }

    /// Every NaN encoding must stay a NaN with the same sign in both
    /// directions (the exact payload/quiet bit is not pinned here).
    #[test]
    fn f16_nan_stays_nan() {
        for sign in [0x0000u16, 0x8000] {
            for payload in 1u16..0x0400 {
                let h = sign | 0x7C00 | payload;
                let f = f16_to_f32_ieee754(h);
                assert!(f.is_nan(), "0x{:04X} did not widen to NaN", h);
                let back = f32_to_f16_ieee754_rne(f);
                assert_eq!(back & 0x8000, sign, "NaN sign lost for 0x{:04X}", h);
                assert_eq!(back & 0x7C00, 0x7C00, "NaN exponent lost for 0x{:04X}", h);
                assert_ne!(back & 0x03FF, 0, "NaN collapsed to Inf for 0x{:04X}", h);
            }
        }
    }

    /// f32 values that are exactly representable in f16 must roundtrip.
    #[test]
    fn f16_exact_representable_values() {
        let exact_values: &[f32] = &[
            0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 0.25, 0.125,
            65504.0, -65504.0, // max f16
            0.000061035156,    // smallest normal f16 (2^-14)
        ];
        for &v in exact_values {
            let h = f32_to_f16_ieee754_rne(v);
            let back = f16_to_f32_ieee754(h);
            assert_eq!(v, back,
                "exact value roundtrip failed: {} → 0x{:04X} → {}", v, h, back);
        }
    }

    /// Magnitudes beyond the f16 range saturate to the signed infinity.
    #[test]
    fn f16_overflow_to_inf() {
        let big = 100000.0f32;
        assert_eq!(f32_to_f16_ieee754_rne(big), 0x7C00); // +Inf
        assert_eq!(f32_to_f16_ieee754_rne(-big), 0xFC00); // −Inf
    }

    /// The batch widening path must agree bit-for-bit with the scalar one.
    /// 200 elements = 12 full 16-lane chunks plus an 8-element tail.
    #[test]
    fn f16_batch_matches_scalar() {
        let input: Vec<u16> = (0..200)
            .map(|i| f32_to_f16_ieee754_rne((i as f32 - 100.0) * 0.5))
            .collect();
        let mut batch_out = vec![0.0f32; 200];
        f16_to_f32_batch_ieee754(&input, &mut batch_out);

        for (i, &h) in input.iter().enumerate() {
            let scalar = f16_to_f32_ieee754(h);
            assert_eq!(batch_out[i].to_bits(), scalar.to_bits(),
                "batch/scalar mismatch at {}: batch=0x{:08X} scalar=0x{:08X}",
                i, batch_out[i].to_bits(), scalar.to_bits());
        }
    }

    /// The batch narrowing path must agree bit-for-bit with the scalar one.
    #[test]
    fn f32_to_f16_batch_rne_matches_scalar() {
        let input: Vec<f32> = (0..200).map(|i| (i as f32 - 100.0) * 0.37).collect();
        let mut batch_out = vec![0u16; 200];
        f32_to_f16_batch_ieee754_rne(&input, &mut batch_out);

        for (i, &v) in input.iter().enumerate() {
            let scalar = f32_to_f16_ieee754_rne(v);
            assert_eq!(batch_out[i], scalar,
                "f32→f16 batch/scalar mismatch at {}: input={} batch=0x{:04X} scalar=0x{:04X}",
                i, v, batch_out[i], scalar);
        }
    }
}
0 commit comments