Skip to content

Commit 601e65d

Browse files
authored
fix(bf16): rewrite f32ToBf16/bf16ToF32 as standard IEEE 754 bit ops (#53)
The previous implementation had multiple bugs: (1) the exponent was clamped to ±7 instead of the IEEE 754 range of ±127 (a 7-bit range instead of an 8-bit one); (2) the mantissa width was wrong (8 bits instead of 7); (3) the bit layout used for decoding was wrong (7-bit exponent + 8-bit mantissa instead of 8 + 7); (4) the bias was handled incorrectly in the frexp path. Standard BF16 is simply the top 16 bits of an IEEE 754 f32: [S:1][E:8][M:7] = bits 31..16 of the f32. This commit replaces 54 lines of broken frexp-based code with 2 lines of correct bit manipulation, using the same bit layout as every major BF16 implementation (PyTorch, TensorFlow, MLX, etc.). It also adds 5 new BF16 tests covering 1.0, 100.0, 1e10, small values, and special values (inf, NaN, ±0). Closes #22
1 parent 63d9a1d commit 601e65d

1 file changed

Lines changed: 41 additions & 54 deletions

File tree

src/formats/formats_root.zig

Lines changed: 41 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -173,63 +173,11 @@ fn fp16ToF32(x: u16) f32 {
173173

174174
// Software bf16 encode/decode (Brain Float 16)
175175
/// Encodes an f32 as BF16 by keeping the upper 16 bits of its IEEE 754
/// bit pattern: [sign:1][exponent:8][mantissa:7].
///
/// The low 16 mantissa bits are discarded (truncation, i.e. rounding toward
/// zero), so finite values, infinities and signed zeros keep their sign and
/// exponent unchanged.
///
/// NaN inputs always encode to a NaN: the quiet bit (0x0040) is forced so
/// that a NaN whose payload lives only in the discarded low bits
/// (e.g. 0x7F800001) does not collapse to the infinity pattern 0x7F80.
fn f32ToBf16(a: f32) u16 {
    const bits: u32 = @bitCast(a);
    const upper: u16 = @truncate(bits >> 16);

    // NaN <=> exponent all ones and a non-zero mantissa, i.e. the magnitude
    // bits compare strictly greater than the infinity pattern.
    const is_nan = (bits & 0x7FFF_FFFF) > 0x7F80_0000;
    if (is_nan) return upper | 0x0040;

    return upper;
}
211178

212179
/// Decodes a BF16 bit pattern into an f32.
///
/// BF16 is the upper half of an f32 bit pattern, so the 16 input bits are
/// placed in bits 31..16 of the result and the low 16 bits are zero.
fn bf16ToF32(x: u16) f32 {
    const widened: u32 = x;
    const f32_bits: u32 = widened << 16;
    return @bitCast(f32_bits);
}
234182

235183
// ═══════════════════════════════════════════════════════════════════
@@ -588,3 +536,42 @@ test "formatBytes" {
588536
try std.testing.expectEqual(@as(usize, 2), formatBytes(.gf16));
589537
try std.testing.expectEqual(@as(usize, 1), formatBytes(.ternary));
590538
}
539+
540+
test "BF16: roundtrip 1.0" {
    // 1.0 as f32 is 0x3F800000, so its BF16 encoding is the top half, 0x3F80.
    const encoded = f32ToBf16(1.0);
    try std.testing.expectEqual(@as(u16, 0x3F80), encoded);

    // Decoding must give back exactly 1.0.
    try std.testing.expectEqual(@as(f32, 1.0), bf16ToF32(encoded));
}
546+
547+
test "BF16: roundtrip 100.0" {
    // Encode then decode 100.0 and require the result to stay within 1.0
    // of the input.
    const original: f32 = 100.0;
    const decoded = bf16ToF32(f32ToBf16(original));
    const abs_err = @abs(decoded - original);
    try std.testing.expect(abs_err < 1.0);
}
553+
554+
test "BF16: roundtrip 1e10" {
    // A large magnitude must survive the roundtrip with under 1% relative error.
    const original: f32 = 1e10;
    const decoded = bf16ToF32(f32ToBf16(original));
    const rel_err = @abs(decoded - original) / original;
    try std.testing.expect(rel_err < 0.01);
}
560+
561+
test "BF16: roundtrip small values" {
    // Every entry is non-zero, so relative error is well defined for all of
    // them. An absolute tolerance of 0.01 would be vacuous for +/-1e-10
    // (a codec returning 0 would still pass), so relative error is used
    // throughout. Discarding the low 16 mantissa bits keeps the relative
    // error below 2^-7 (~0.0078), which fits the 1% bound.
    const values = [_]f32{ 0.5, -0.5, 2.0, -2.0, 3.14, -3.14, 1e-10, -1e-10 };
    const max_rel_err: f32 = 0.01;
    for (values) |v| {
        const back = bf16ToF32(f32ToBf16(v));
        const rel_err = @abs(back - v) / @abs(v);
        try std.testing.expect(rel_err < max_rel_err);
    }
}
570+
571+
test "BF16: special values" {
    try std.testing.expectEqual(@as(u16, 0x3F80), f32ToBf16(1.0));

    // Infinities: require the exact BF16 patterns and a true infinity after
    // decoding (a "> 1e30" check would also accept a large finite value).
    const pos_inf = std.math.inf(f32);
    try std.testing.expectEqual(@as(u16, 0x7F80), f32ToBf16(pos_inf));
    try std.testing.expectEqual(@as(u16, 0xFF80), f32ToBf16(-pos_inf));
    const decoded_pos_inf = bf16ToF32(0x7F80);
    const decoded_neg_inf = bf16ToF32(0xFF80);
    try std.testing.expect(std.math.isInf(decoded_pos_inf) and decoded_pos_inf > 0);
    try std.testing.expect(std.math.isInf(decoded_neg_inf) and decoded_neg_inf < 0);

    // NaN must stay NaN across the roundtrip.
    try std.testing.expect(std.math.isNan(bf16ToF32(f32ToBf16(std.math.nan(f32)))));

    // Signed zeros: check the encoded bits and that the sign survives decoding.
    try std.testing.expectEqual(@as(u16, 0), f32ToBf16(0.0));
    try std.testing.expectEqual(@as(u16, 0x8000), f32ToBf16(-0.0));
    try std.testing.expectEqual(@as(u32, 0x0000_0000), @as(u32, @bitCast(bf16ToF32(0x0000))));
    try std.testing.expectEqual(@as(u32, 0x8000_0000), @as(u32, @bitCast(bf16ToF32(0x8000))));
}

0 commit comments

Comments
 (0)