Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 113 additions & 41 deletions src/formats/formats_root.zig
Original file line number Diff line number Diff line change
Expand Up @@ -118,62 +118,62 @@ pub fn f32ToGf16(a: f32) u16 {
// ═══════════════════════════════════════════════════════════════════
// Software fp16 encode/decode (IEEE 754 binary16)
/// Encode an `f32` as IEEE 754 binary16 bits (1 sign, 5 exponent, 10 mantissa).
///
/// - NaN maps to the canonical quiet NaN `0x7E00` (sign and payload dropped).
/// - Signed zero and signed infinity keep their sign.
/// - Magnitudes beyond the binary16 range become signed infinity.
/// - Magnitudes below the smallest normal (2^-14) are encoded as subnormals;
///   anything that rounds below 2^-24 becomes signed zero.
/// - Discarded mantissa bits are rounded to nearest, ties to even.
fn f32ToFp16(a: f32) u16 {
    if (std.math.isNan(a)) return 0x7E00;

    const bits: u32 = @bitCast(a);
    const sign: u16 = @intCast((bits >> 16) & 0x8000);
    const abs_bits: u32 = bits & 0x7FFFFFFF;

    if (abs_bits == 0) return sign;
    if (abs_bits == 0x7F800000) return sign | 0x7C00;

    // Unbiased f32 exponent and the 23 stored mantissa bits.
    const f32_exp: i32 = @as(i32, @intCast(abs_bits >> 23)) - 127;
    const f32_mant: u32 = abs_bits & 0x7FFFFF;

    // Larger than any finite binary16 value.
    if (f32_exp > 15) return sign | 0x7C00;

    if (f32_exp >= -14) {
        // Normal binary16: re-bias the exponent and keep the top 10 mantissa bits.
        const exp_field: u16 = @intCast(f32_exp + 15);
        const mant_field: u16 = @intCast(f32_mant >> 13);
        const normal_bits: u16 = (exp_field << 10) | mant_field;

        // Round on the 13 discarded bits. A carry out of the mantissa bumps the
        // exponent, and at the top of the range it yields 0x7C00 (infinity).
        const lost: u32 = f32_mant & 0x1FFF;
        const bump = lost > 0x1000 or (lost == 0x1000 and (normal_bits & 1) == 1);
        return sign | (normal_bits + @intFromBool(bump));
    }

    // Below 2^-25 the value rounds to zero. This also covers f32 subnormals and
    // keeps the shift amount below within the range of u5.
    if (f32_exp < -25) return sign;

    // Subnormal binary16: value = significand * 2^(f32_exp - 23) and the target
    // unit is 2^-24, so the mantissa is significand >> (-f32_exp - 1).
    const significand: u32 = f32_mant | 0x800000; // restore the implicit leading 1
    const shift: u5 = @intCast(-f32_exp - 1); // 14..24
    const halfway: u32 = @as(u32, 1) << (shift - 1);
    const remainder: u32 = significand & ((@as(u32, 1) << shift) - 1);
    const subnormal_bits: u16 = @intCast(significand >> shift);
    const bump_sub = remainder > halfway or (remainder == halfway and (subnormal_bits & 1) == 1);
    // Rounding up from 0x03FF gives 0x0400, the smallest normal value.
    return sign | (subnormal_bits + @intFromBool(bump_sub));
}

/// Decode IEEE 754 binary16 bits into an `f32`.
///
/// Every binary16 value is exactly representable in `f32`, so the conversion
/// is lossless. Signed zeros, subnormals and infinities are preserved; any NaN
/// input decodes to a quiet NaN carrying the input's sign (payload dropped).
fn fp16ToF32(x: u16) f32 {
    const sign: u32 = @as(u32, x & 0x8000) << 16;
    const exp_field: u32 = (x >> 10) & 0x1F;
    const mant_field: u32 = x & 0x03FF;

    if (exp_field == 0x1F) {
        // All-ones exponent: infinity when the mantissa is zero, NaN otherwise.
        const special: u32 = if (mant_field == 0) 0x7F800000 else 0x7FC00000;
        return @bitCast(sign | special);
    }

    if (exp_field == 0) {
        if (mant_field == 0) return @bitCast(sign); // signed zero

        // Subnormal: value = mant_field * 2^-24. Shift the mantissa left until
        // its leading 1 reaches the implicit-bit position (bit 23), lowering
        // the biased f32 exponent once per shift. Counting down from 113 lands
        // on 103 for the smallest subnormal (2^-24) and 112 for 2^-15.
        var mant: u32 = mant_field << 13;
        var biased_exp: u32 = 113;
        while ((mant & 0x00800000) == 0) {
            mant <<= 1;
            biased_exp -= 1;
        }
        return @bitCast(sign | (biased_exp << 23) | (mant & 0x7FFFFF));
    }

    // Normal: re-bias the exponent (127 - 15 = 112) and widen the mantissa.
    return @bitCast(sign | ((exp_field + 112) << 23) | (mant_field << 13));
}

// Software bf16 encode/decode (Brain Float 16) — IEEE 754 canonical
/// Encode an `f32` as bfloat16 bits (the upper 16 bits of the f32 pattern).
///
/// NaN maps to the canonical quiet NaN `0x7FC0`. All other values are rounded
/// to nearest, ties to even, on the 16 discarded low bits; signed zero and
/// signed infinity pass through unchanged.
fn f32ToBf16(a: f32) u16 {
    if (std.math.isNan(a)) return 0x7FC0;
    const bits: u32 = @bitCast(a);
    // Adding 0x7FFF plus the lowest kept bit rounds up whenever the dropped
    // half exceeds 0x8000, and on an exact tie only when the kept value is odd.
    const rounding: u32 = ((bits >> 16) & 1) + 0x7FFF;
    // With NaN filtered out, bits is at most 0xFF800000, so the sum cannot
    // overflow; the saturating add is purely defensive.
    return @intCast((bits +| rounding) >> 16);
}

fn bf16ToF32(x: u16) f32 {
Expand Down Expand Up @@ -568,10 +568,82 @@ test "BF16: roundtrip small values" {
}
}

test "FP16: roundtrip basic values" {
    // Each sample must survive an encode/decode cycle with a relative error
    // below 0.5%.
    const max_rel_err: f32 = 0.005;
    const samples = [_]f32{ 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 0.1, 0.25, 1.5 };
    for (samples) |sample| {
        const decoded = fp16ToF32(f32ToFp16(sample));
        // The floor on the denominator guards against division by zero.
        const scale = @max(@abs(sample), 1e-30);
        const rel_err = @abs(decoded - sample) / scale;
        try std.testing.expect(rel_err < max_rel_err);
    }
}

test "FP16: special values" {
    const testing = std.testing;
    const pos_inf = std.math.inf(f32);

    // Signed zeros and signed infinities have fixed bit patterns.
    try testing.expectEqual(@as(u16, 0x0000), f32ToFp16(0.0));
    try testing.expectEqual(@as(u16, 0x8000), f32ToFp16(-0.0));
    try testing.expectEqual(@as(u16, 0x7C00), f32ToFp16(pos_inf));
    try testing.expectEqual(@as(u16, 0xFC00), f32ToFp16(-pos_inf));

    // NaN only has to decode back to some NaN; the exact bits are not pinned.
    const decoded_nan = fp16ToF32(f32ToFp16(std.math.nan(f32)));
    try testing.expect(std.math.isNan(decoded_nan));
}

test "FP16: large values (full IEEE exponent)" {
    const testing = std.testing;

    // 100.0 must come back within an absolute error of 1.0.
    const hundred = fp16ToF32(f32ToFp16(100.0));
    try testing.expect(@abs(hundred - 100.0) < 1.0);

    // 65000.0 must stay finite and land between 60000 and 65536.
    const near_max = fp16ToF32(f32ToFp16(65000.0));
    try testing.expect(near_max > 60000.0);
    try testing.expect(near_max < 65536.0);
}

test "FP16: overflow to infinity" {
    // 1e10 must encode as positive infinity and decode as infinite.
    const encoded = f32ToFp16(1e10);
    try std.testing.expectEqual(@as(u16, 0x7C00), encoded);

    const decoded = fp16ToF32(encoded);
    try std.testing.expect(std.math.isInf(decoded));
}

test "FP16: roundtrip 1.0 exact" {
    // 1.0 must encode to 0x3C00 and decode back with no error at all.
    const one_bits = f32ToFp16(1.0);
    try std.testing.expectEqual(@as(u16, 0x3C00), one_bits);

    const one_back = fp16ToF32(one_bits);
    try std.testing.expectEqual(@as(f32, 1.0), one_back);
}

test "FP16: denormal roundtrip" {
    // Decodes the bit pattern 0x0001 and checks the result is a tiny positive
    // number. Only the decode direction is exercised here.
    const smallest_bits: u16 = 0x0001;
    const decoded = fp16ToF32(smallest_bits);
    try std.testing.expect(decoded > 0.0);
    try std.testing.expect(decoded < 0.001);
}

test "BF16: special values" {
    const testing = std.testing;
    const pos_inf = std.math.inf(f32);

    // 1.0 has an exact bf16 representation.
    try testing.expectEqual(@as(u16, 0x3F80), f32ToBf16(1.0));

    // Infinity and NaN must survive an encode/decode cycle.
    const inf_back = bf16ToF32(f32ToBf16(pos_inf));
    try testing.expect(inf_back > 1e30);
    const nan_back = bf16ToF32(f32ToBf16(std.math.nan(f32)));
    try testing.expect(std.math.isNan(nan_back));

    // Signed zeros and signed infinities have fixed bit patterns.
    try testing.expectEqual(@as(u16, 0), f32ToBf16(0.0));
    try testing.expectEqual(@as(u16, 0x8000), f32ToBf16(-0.0));
    try testing.expectEqual(@as(u16, 0x7F80), f32ToBf16(pos_inf));
    try testing.expectEqual(@as(u16, 0xFF80), f32ToBf16(-pos_inf));
}

test "BF16: large values do not flush" {
    const testing = std.testing;

    // A very large magnitude must stay in the right ballpark after a roundtrip.
    const huge = bf16ToF32(f32ToBf16(1e10));
    try testing.expect(huge > 5e9);
    try testing.expect(huge < 2e10);

    // A very small magnitude must neither vanish nor blow up.
    const tiny = bf16ToF32(f32ToBf16(1e-10));
    try testing.expect(tiny > 5e-11);
    try testing.expect(tiny < 2e-9);
}

test "BF16: quantizeValue roundtrip all formats" {
    // Quantizing 42.0 through each 16-bit format must stay within 5% relative
    // error of the input.
    const reference: f32 = 42.0;

    const via_gf16 = quantizeValue(reference, .gf16);
    try std.testing.expect(@abs(via_gf16 - reference) / reference < 0.05);

    const via_bf16 = quantizeValue(reference, .bf16);
    try std.testing.expect(@abs(via_bf16 - reference) / reference < 0.05);

    const via_fp16 = quantizeValue(reference, .fp16);
    try std.testing.expect(@abs(via_fp16 - reference) / reference < 0.05);
}
Loading