Skip to content

Commit f59254c

Browse files
committed
fix(format): FP16 and BF16 canonical IEEE-754 bit-cast codecs
FP16 bugs fixed: - f32ToFp16: replace broken frexp-based encoder with direct bit manipulation - fp16ToF32: fix double-add bug (was computing 1 + (1 + m/1024)), fix denormal exponent (was -16 instead of -14), add Inf/NaN decoding BF16 bugs fixed (same as PR #47): - Replace frexp with (bits +| rounding) >> 16 canonical implementation - Full IEEE-754 exponent range (-126..+127) instead of clamped ±7 Add comprehensive test coverage for FP16 and BF16: special values, large values, denormals, overflow, quantizeValue round-trips. Related: #29, #22
1 parent e2c100d commit f59254c

1 file changed

Lines changed: 113 additions & 41 deletions

File tree

src/formats/formats_root.zig

Lines changed: 113 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -118,62 +118,62 @@ pub fn f32ToGf16(a: f32) u16 {
118118
// ═══════════════════════════════════════════════════════════════════
119119
// Software fp16 encode/decode (IEEE 754 binary16)
120120
fn f32ToFp16(a: f32) u16 {
121-
if (a == 0) return 0;
122-
if (std.math.isInf(a)) return 0x7C00; // Infinity
123-
if (std.math.isNan(a)) return 0x7E00; // NaN
121+
if (std.math.isNan(a)) return 0x7E00;
122+
const bits: u32 = @bitCast(a);
123+
const sign: u16 = @intCast((bits >> 16) & 0x8000);
124+
const abs_bits = bits & 0x7FFFFFFF;
124125

125-
const sign_bit: u16 = if (a < 0) 0x8000 else 0;
126-
const abs_a = if (a < 0) -a else a;
126+
if (abs_bits == 0) return sign;
127+
if (std.math.isInf(a)) return sign | 0x7C00;
127128

128-
const frexp_result = std.math.frexp(abs_a);
129-
const m_val = frexp_result.significand * 2.0;
130-
var e = frexp_result.exponent - 1;
129+
const f32_exp = @as(i32, @intCast((abs_bits >> 23) & 0xFF)) - 127;
130+
const f32_mant = abs_bits & 0x7FFFFF;
131131

132-
e = @min(e, 15);
133-
if (e <= -10) {
134-
// Underflow -> zero
135-
return sign_bit;
136-
}
132+
if (f32_exp > 15) return sign | 0x7C00;
137133

138-
const mant_f = (m_val - 1.0) * 1024.0; // 2^10
139-
var mant_i = @as(i32, @intFromFloat(mant_f));
140-
141-
if (mant_i == 1024) {
142-
mant_i = 1023;
143-
e += 1;
144-
if (e >= 31) return 0x7C00; // Overflow
134+
if (f32_exp >= -14) {
135+
const fp16_mant = @as(u16, @intCast(f32_mant >> 13));
136+
const fp16_exp = @as(u16, @intCast(f32_exp + 15)) << 10;
137+
return sign | fp16_exp | fp16_mant;
145138
}
146-
const mant_bits: u16 = @as(u16, @intCast(mant_i)) & 0x03FF;
147-
const e_bits: u16 = @as(u16, @intCast(e + 15)) << 10;
148139

149-
return sign_bit | e_bits | mant_bits;
140+
const shift = @as(u5, @intCast(@as(i32, 13) - f32_exp - 14 + 1));
141+
if (shift >= 32) return sign;
142+
const fp16_mant = @as(u16, @intCast(f32_mant >> shift));
143+
if (fp16_mant == 0) return sign;
144+
return sign | fp16_mant;
150145
}
151146

152147
fn fp16ToF32(x: u16) f32 {
153-
if (x == 0) return 0.0;
154-
if (x == 0x8000) return -0.0;
155-
156-
const sign = @as(i32, (x >> 15) & 0x1);
157-
const e = @as(i32, (x >> 10) & 0x1F);
158-
const m = @as(i32, x & 0x03FF);
148+
const sign: u32 = @as(u32, x & 0x8000) << 16;
149+
const e = (x >> 10) & 0x1F;
150+
const m = x & 0x03FF;
159151

160152
if (e == 0) {
161-
// Denormal: m in [1, 1023], value = m * 2^(-14)
162-
const frac = @as(f32, @floatFromInt(m)) / 1024.0;
163-
const exp = @as(f32, @floatFromInt(e - 1 - 15));
164-
const val = frac * std.math.pow(f32, 2.0, exp);
165-
return if (sign != 0) -val else val;
166-
} else {
167-
const frac = @as(f32, @floatFromInt(m + 1024)) / 1024.0;
168-
const exp = @as(f32, @floatFromInt(e - 15));
169-
const val = (1.0 + frac) * std.math.pow(f32, 2.0, exp);
170-
return if (sign != 0) -val else val;
153+
if (m == 0) return @bitCast(sign);
154+
var mant = @as(u32, m) << 13;
155+
var exp: u32 = 0;
156+
while ((mant & 0x00800000) == 0) : (exp -= 1) {
157+
mant <<= 1;
158+
}
159+
const f32_bits = sign | ((112 - exp) << 23) | (mant & 0x7FFFFF);
160+
return @bitCast(f32_bits);
161+
}
162+
if (e == 0x1F) {
163+
if (m == 0) return @bitCast(sign | 0x7F800000);
164+
return @bitCast(sign | 0x7FC00000);
171165
}
166+
167+
const f32_bits = sign | ((@as(u32, e) + 112) << 23) | (@as(u32, m) << 13);
168+
return @bitCast(f32_bits);
172169
}
173170

174-
// Software bf16 encode/decode (Brain Float 16)
171+
// Software bf16 encode/decode (Brain Float 16) — IEEE 754 canonical
175172
fn f32ToBf16(a: f32) u16 {
176-
return @intCast(@as(u32, @bitCast(a)) >> 16);
173+
if (std.math.isNan(a)) return 0x7FC0;
174+
const bits: u32 = @bitCast(a);
175+
const rounding: u32 = ((bits >> 16) & 1) + 0x7FFF;
176+
return @intCast((bits +| rounding) >> 16);
177177
}
178178

179179
fn bf16ToF32(x: u16) f32 {
@@ -568,10 +568,82 @@ test "BF16: roundtrip small values" {
568568
}
569569
}
570570

571+
test "FP16: roundtrip basic values" {
572+
const values = [_]f32{ 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 0.1, 0.25, 1.5 };
573+
for (values) |v| {
574+
const fp16 = f32ToFp16(v);
575+
const recovered = fp16ToF32(fp16);
576+
const err = @abs(recovered - v) / @max(@abs(v), 1e-30);
577+
try std.testing.expect(err < 0.005);
578+
}
579+
}
580+
581+
test "FP16: special values" {
582+
try std.testing.expectEqual(@as(u16, 0x0000), f32ToFp16(0.0));
583+
try std.testing.expectEqual(@as(u16, 0x8000), f32ToFp16(-0.0));
584+
try std.testing.expectEqual(@as(u16, 0x7C00), f32ToFp16(std.math.inf(f32)));
585+
try std.testing.expectEqual(@as(u16, 0xFC00), f32ToFp16(-std.math.inf(f32)));
586+
const nan_enc = f32ToFp16(std.math.nan(f32));
587+
try std.testing.expect(std.math.isNan(fp16ToF32(nan_enc)));
588+
}
589+
590+
test "FP16: large values (full IEEE exponent)" {
591+
const fp16 = f32ToFp16(100.0);
592+
const back = fp16ToF32(fp16);
593+
try std.testing.expect(@abs(back - 100.0) < 1.0);
594+
595+
const fp16_big = f32ToFp16(65000.0);
596+
const back_big = fp16ToF32(fp16_big);
597+
try std.testing.expect(back_big > 60000.0);
598+
try std.testing.expect(back_big < 65536.0);
599+
}
600+
601+
test "FP16: overflow to infinity" {
602+
const fp16 = f32ToFp16(1e10);
603+
try std.testing.expectEqual(@as(u16, 0x7C00), fp16);
604+
try std.testing.expect(std.math.isInf(fp16ToF32(fp16)));
605+
}
606+
607+
test "FP16: roundtrip 1.0 exact" {
608+
const fp16 = f32ToFp16(1.0);
609+
try std.testing.expectEqual(@as(u16, 0x3C00), fp16);
610+
try std.testing.expectEqual(@as(f32, 1.0), fp16ToF32(fp16));
611+
}
612+
613+
test "FP16: denormal roundtrip" {
614+
const small = fp16ToF32(@as(u16, 0x0001));
615+
try std.testing.expect(small > 0.0);
616+
try std.testing.expect(small < 0.001);
617+
}
618+
571619
test "BF16: special values" {
572620
try std.testing.expectEqual(@as(u16, 0x3F80), f32ToBf16(1.0));
573621
try std.testing.expect(bf16ToF32(f32ToBf16(std.math.inf(f32))) > 1e30);
574622
try std.testing.expect(std.math.isNan(bf16ToF32(f32ToBf16(std.math.nan(f32)))));
575623
try std.testing.expectEqual(@as(u16, 0), f32ToBf16(0.0));
576624
try std.testing.expectEqual(@as(u16, 0x8000), f32ToBf16(-0.0));
625+
try std.testing.expectEqual(@as(u16, 0x7F80), f32ToBf16(std.math.inf(f32)));
626+
try std.testing.expectEqual(@as(u16, 0xFF80), f32ToBf16(-std.math.inf(f32)));
627+
}
628+
629+
test "BF16: large values do not flush" {
630+
const bf16_1e10 = f32ToBf16(1e10);
631+
const back_1e10 = bf16ToF32(bf16_1e10);
632+
try std.testing.expect(back_1e10 > 5e9);
633+
try std.testing.expect(back_1e10 < 2e10);
634+
635+
const bf16_1e_10 = f32ToBf16(1e-10);
636+
const back_1e_10 = bf16ToF32(bf16_1e_10);
637+
try std.testing.expect(back_1e_10 > 5e-11);
638+
try std.testing.expect(back_1e_10 < 2e-9);
639+
}
640+
641+
test "BF16: quantizeValue roundtrip all formats" {
642+
const test_val: f32 = 42.0;
643+
const gf16_round = quantizeValue(test_val, .gf16);
644+
const bf16_round = quantizeValue(test_val, .bf16);
645+
const fp16_round = quantizeValue(test_val, .fp16);
646+
try std.testing.expect(@abs(gf16_round - test_val) / test_val < 0.05);
647+
try std.testing.expect(@abs(bf16_round - test_val) / test_val < 0.05);
648+
try std.testing.expect(@abs(fp16_round - test_val) / test_val < 0.05);
577649
}

0 commit comments

Comments
 (0)