// Provenance: post-image of the FP16/BF16 codec hunks of git patch
//   f59254ce46dadd19c74f33fcf404ed009ba61a3d
//   Author:  Dmitriy Vasilev, Thu, 30 Apr 2026 01:03:37 +0700
//   Subject: fix(format): FP16 and BF16 canonical IEEE-754 bit-cast codecs
//   Target:  src/formats/formats_root.zig (hunks at lines 118 and 568)
//   Related: #29, #22
//
// Commit message:
//   FP16 bugs fixed:
//   - f32ToFp16: replace broken frexp-based encoder with direct bit manipulation
//   - fp16ToF32: fix double-add bug (was computing 1 + (1 + m/1024)), fix
//     denormal exponent (was -16 instead of -14), add Inf/NaN decoding
//   BF16 bugs fixed (same as PR #47):
//   - Replace frexp with (bits +| rounding) >> 16 canonical implementation
//   - Full IEEE-754 exponent range (-126..+127) instead of clamped ±7
//   Add comprehensive test coverage for FP16 and BF16: special values, large
//   values, denormals, overflow, quantizeValue round-trips.
//
// Review fixes applied on top of the patch:
//   - f32ToFp16: the subnormal path cast an out-of-range shift to u5 (safety
//     panic for inputs below 2^-31, e.g. 1e-10), dropped the implicit leading
//     1 bit, used a shift that was off by one, and truncated instead of
//     rounding. It now rounds to nearest, ties to even, like f32ToBf16.
//   - fp16ToF32: the subnormal path decremented a u32 from 0 (integer
//     overflow panic in safe builds) and computed the wrong exponent.
//
// `bf16ToF32` and `quantizeValue`, used by the tests below, are defined
// elsewhere in formats_root.zig and are not part of this section.

// ═══════════════════════════════════════════════════════════════════
// Software fp16 encode/decode (IEEE 754 binary16)

/// Shift `value` right by `shift` bits, rounding to nearest with ties to even.
/// `shift` must be at least 1 (the halfway mask is `1 << (shift - 1)`).
fn roundShiftRightTiesEven(value: u32, shift: u5) u32 {
    const kept = value >> shift;
    const dropped = value & ((@as(u32, 1) << shift) - 1);
    const halfway = @as(u32, 1) << (shift - 1);
    const round_up = dropped > halfway or (dropped == halfway and (kept & 1) == 1);
    return kept + @intFromBool(round_up);
}

/// Encode an `f32` as IEEE 754 binary16 bits, rounding to nearest (ties to even).
///
/// - NaN encodes as the quiet NaN `0x7E00` (sign and payload are not kept).
/// - Magnitudes too large for binary16 (including ±Inf) encode as signed infinity.
/// - Magnitudes below 2^-14 encode as binary16 subnormals; anything that
///   rounds below half of the smallest subnormal (2^-25) becomes signed zero.
fn f32ToFp16(a: f32) u16 {
    if (std.math.isNan(a)) return 0x7E00;

    const bits: u32 = @bitCast(a);
    const sign: u16 = @intCast((bits >> 16) & 0x8000);
    const exp_field: u32 = (bits >> 23) & 0xFF; // biased by 127
    const mant: u32 = bits & 0x007F_FFFF;

    // Unbiased exponent > 15 (also covers ±Inf, exp_field == 255).
    if (exp_field > 142) return sign | 0x7C00;

    // Magnitude < 2^-25 (also covers ±0 and every f32 subnormal): rounds to zero.
    if (exp_field < 102) return sign;

    if (exp_field >= 113) {
        // Normal binary16. Rebias 127 -> 15 and drop 13 mantissa bits. The
        // exponent sits directly above the mantissa, so a rounding carry
        // bumps the exponent and correctly turns 0x7BFF + 1 into Inf.
        const rebiased = ((exp_field - 112) << 23) | mant;
        const half = roundShiftRightTiesEven(rebiased, 13);
        return sign | @as(u16, @intCast(half));
    }

    // Subnormal binary16: result = value / 2^-24. Restore the implicit
    // leading 1 and shift by 14 (for 2^-15) up to 24 (for 2^-25). A carry
    // out of the 10-bit mantissa yields 0x0400, the smallest normal.
    const significand = mant | 0x0080_0000;
    const shift: u5 = @intCast(126 - exp_field);
    const half = roundShiftRightTiesEven(significand, shift);
    return sign | @as(u16, @intCast(half));
}

/// Decode IEEE 754 binary16 bits into an `f32`. Every binary16 value is
/// exactly representable, so no rounding occurs.
///
/// Infinities keep their sign; any NaN decodes to a quiet NaN with the
/// original sign bit (the payload is not preserved).
fn fp16ToF32(x: u16) f32 {
    const sign: u32 = @as(u32, x & 0x8000) << 16;
    const exp_field: u32 = (x >> 10) & 0x1F; // biased by 15
    const mant: u32 = x & 0x03FF;

    if (exp_field == 0x1F) {
        if (mant == 0) return @bitCast(sign | 0x7F80_0000);
        return @bitCast(sign | 0x7FC0_0000);
    }

    if (exp_field == 0) {
        if (mant == 0) return @bitCast(sign); // signed zero

        // Subnormal: value = mant * 2^-24. Normalise so the leading 1
        // reaches bit 23; each shift lowers the exponent by one.
        var significand: u32 = mant << 13;
        var shifts: u32 = 0;
        while ((significand & 0x0080_0000) == 0) : (shifts += 1) {
            significand <<= 1;
        }
        // With zero shifts the value would be 2^-14, i.e. f32 field 113.
        return @bitCast(sign | ((113 - shifts) << 23) | (significand & 0x007F_FFFF));
    }

    // Normal: rebias 15 -> 127 (add 112) and widen the mantissa 10 -> 23 bits.
    return @bitCast(sign | ((exp_field + 112) << 23) | (mant << 13));
}

// Software bf16 encode/decode (Brain Float 16) — IEEE 754 canonical

/// Encode an `f32` as bfloat16 bits (the upper 16 bits of the f32 pattern),
/// rounding to nearest with ties to even. NaN encodes as the quiet NaN
/// `0x7FC0`.
fn f32ToBf16(a: f32) u16 {
    if (std.math.isNan(a)) return 0x7FC0;
    const bits: u32 = @bitCast(a);
    // 0x7FFF plus the lowest kept bit implements ties-to-even.
    const rounding: u32 = ((bits >> 16) & 1) + 0x7FFF;
    // With NaN excluded, bits <= 0xFF80_0000, so the add cannot overflow;
    // the saturating operator is only a safety net.
    return @intCast((bits +| rounding) >> 16);
}

test "FP16: roundtrip basic values" {
    const values = [_]f32{ 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 0.1, 0.25, 1.5 };
    for (values) |v| {
        const fp16 = f32ToFp16(v);
        const recovered = fp16ToF32(fp16);
        const err = @abs(recovered - v) / @max(@abs(v), 1e-30);
        try std.testing.expect(err < 0.005);
    }
}

test "FP16: special values" {
    try std.testing.expectEqual(@as(u16, 0x0000), f32ToFp16(0.0));
    try std.testing.expectEqual(@as(u16, 0x8000), f32ToFp16(-0.0));
    try std.testing.expectEqual(@as(u16, 0x7C00), f32ToFp16(std.math.inf(f32)));
    try std.testing.expectEqual(@as(u16, 0xFC00), f32ToFp16(-std.math.inf(f32)));
    const nan_enc = f32ToFp16(std.math.nan(f32));
    try std.testing.expect(std.math.isNan(fp16ToF32(nan_enc)));
}

test "FP16: large values (full IEEE exponent)" {
    const fp16 = f32ToFp16(100.0);
    const back = fp16ToF32(fp16);
    try std.testing.expect(@abs(back - 100.0) < 1.0);

    const fp16_big = f32ToFp16(65000.0);
    const back_big = fp16ToF32(fp16_big);
    try std.testing.expect(back_big > 60000.0);
    try std.testing.expect(back_big < 65536.0);
}

test "FP16: overflow to infinity" {
    const fp16 = f32ToFp16(1e10);
    try std.testing.expectEqual(@as(u16, 0x7C00), fp16);
    try std.testing.expect(std.math.isInf(fp16ToF32(fp16)));
}

test "FP16: roundtrip 1.0 exact" {
    const fp16 = f32ToFp16(1.0);
    try std.testing.expectEqual(@as(u16, 0x3C00), fp16);
    try std.testing.expectEqual(@as(f32, 1.0), fp16ToF32(fp16));
}

test "FP16: denormal roundtrip" {
    const small = fp16ToF32(@as(u16, 0x0001));
    try std.testing.expect(small > 0.0);
    try std.testing.expect(small < 0.001);
    // Smallest subnormal is exactly 2^-24 and must survive a round trip.
    try std.testing.expectEqual(@as(f32, 0x1.0p-24), small);
    try std.testing.expectEqual(@as(u16, 0x0001), f32ToFp16(small));
}

test "FP16: subnormal and underflow encoding" {
    try std.testing.expectEqual(@as(u16, 0x0400), f32ToFp16(0x1.0p-14)); // smallest normal
    try std.testing.expectEqual(@as(u16, 0x0200), f32ToFp16(0x1.0p-15)); // largest power-of-two subnormal
    try std.testing.expectEqual(@as(u16, 0x8200), f32ToFp16(-0x1.0p-15));
    try std.testing.expectEqual(@as(u16, 0x0000), f32ToFp16(0x1.0p-25)); // tie -> even (zero)
    try std.testing.expectEqual(@as(u16, 0x0000), f32ToFp16(1e-10)); // far below range
    try std.testing.expectEqual(@as(u16, 0x8000), f32ToFp16(-1e-10));
    try std.testing.expectEqual(@as(u16, 0x0000), f32ToFp16(std.math.floatTrueMin(f32)));
}

test "FP16: rounding at the top of the range" {
    try std.testing.expectEqual(@as(u16, 0x7BFF), f32ToFp16(65504.0)); // max finite
    try std.testing.expectEqual(@as(u16, 0x7C00), f32ToFp16(65520.0)); // tie rounds up to Inf
    try std.testing.expectEqual(@as(f32, 65504.0), fp16ToF32(0x7BFF));
}

test "FP16: every non-NaN bit pattern round-trips" {
    var raw: u32 = 0;
    while (raw <= 0xFFFF) : (raw += 1) {
        const half: u16 = @intCast(raw);
        const is_nan = ((half >> 10) & 0x1F) == 0x1F and (half & 0x03FF) != 0;
        if (is_nan) continue;
        try std.testing.expectEqual(half, f32ToFp16(fp16ToF32(half)));
    }
}

test "BF16: special values" {
    try std.testing.expectEqual(@as(u16, 0x3F80), f32ToBf16(1.0));
    try std.testing.expect(bf16ToF32(f32ToBf16(std.math.inf(f32))) > 1e30);
    try std.testing.expect(std.math.isNan(bf16ToF32(f32ToBf16(std.math.nan(f32)))));
    try std.testing.expectEqual(@as(u16, 0), f32ToBf16(0.0));
    try std.testing.expectEqual(@as(u16, 0x8000), f32ToBf16(-0.0));
    try std.testing.expectEqual(@as(u16, 0x7F80), f32ToBf16(std.math.inf(f32)));
    try std.testing.expectEqual(@as(u16, 0xFF80), f32ToBf16(-std.math.inf(f32)));
}

test "BF16: large values do not flush" {
    const bf16_1e10 = f32ToBf16(1e10);
    const back_1e10 = bf16ToF32(bf16_1e10);
    try std.testing.expect(back_1e10 > 5e9);
    try std.testing.expect(back_1e10 < 2e10);

    const bf16_1e_10 = f32ToBf16(1e-10);
    const back_1e_10 = bf16ToF32(bf16_1e_10);
    try std.testing.expect(back_1e_10 > 5e-11);
    try std.testing.expect(back_1e_10 < 2e-9);
}

test "BF16: quantizeValue roundtrip all formats" {
    const test_val: f32 = 42.0;
    const gf16_round = quantizeValue(test_val, .gf16);
    const bf16_round = quantizeValue(test_val, .bf16);
    const fp16_round = quantizeValue(test_val, .fp16);
    try std.testing.expect(@abs(gf16_round - test_val) / test_val < 0.05);
    try std.testing.expect(@abs(bf16_round - test_val) / test_val < 0.05);
    try std.testing.expect(@abs(fp16_round - test_val) / test_val < 0.05);
}