Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/c/gf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,32 @@ gf16_t gf16_max(gf16_t a, gf16_t b);
*/
gf16_t gf16_fma(gf16_t a, gf16_t b, gf16_t c);

/**
* φ-optimized fused multiply-add
*
* Dequantizes inputs from φ-space, computes a × b + c in f32,
* then φ-quantizes the result back.
*
* @param a First operand (φ-quantized)
* @param b Second operand (φ-quantized)
* @param c Third operand (φ-quantized)
* @return φ-quantized result of a × b + c
*/
gf16_t gf16_phi_fma(gf16_t a, gf16_t b, gf16_t c);

/**
* φ-optimized fused multiply-subtract
*
* Dequantizes inputs from φ-space, computes a × b - c in f32,
* then φ-quantizes the result back.
*
* @param a First operand (φ-quantized)
* @param b Second operand (φ-quantized)
* @param c Third operand (φ-quantized)
* @return φ-quantized result of a × b - c
*/
gf16_t gf16_phi_fms(gf16_t a, gf16_t b, gf16_t c);

/*======================================================================
* Constants
*======================================================================*/
Expand Down
27 changes: 26 additions & 1 deletion src/c_abi.zig
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,20 @@ export fn gf16_max(a: gf16_t, b: gf16_t) callconv(.c) gf16_t {
}

export fn gf16_fma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
// Compute a * b + c in f32, then round to GF16
const fa = rawToGf16(a).toF32();
const fb = rawToGf16(b).toF32();
const fc = rawToGf16(c).toF32();
return gf16ToRaw(golden.GF16.fromF32(fa * fb + fc));
}

export fn gf16_phi_fma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
return gf16ToRaw(golden.GF16.phiFma(rawToGf16(a), rawToGf16(b), rawToGf16(c)));
}

export fn gf16_phi_fms(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
return gf16ToRaw(golden.GF16.phiFms(rawToGf16(a), rawToGf16(b), rawToGf16(c)));
}

// ═══════════════════════════════════════════════════════════════════
// Library Info
// ═════════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -309,6 +316,24 @@ test "C-ABI: gf16_fma" {
try std.testing.expectApproxEqAbs(@as(f32, 10.0), val, 0.05);
}

test "C-ABI: gf16_phi_fma" {
const a = gf16_phi_quantize(2.0);
const b = gf16_phi_quantize(3.0);
const c = gf16_phi_quantize(4.0);
const result = gf16_phi_fma(a, b, c);
const deq = gf16_phi_dequantize(result);
try std.testing.expectApproxEqAbs(@as(f32, 10.0), deq, 1.5);
}

test "C-ABI: gf16_phi_fms" {
const a = gf16_phi_quantize(5.0);
const b = gf16_phi_quantize(3.0);
const c = gf16_phi_quantize(4.0);
const result = gf16_phi_fms(a, b, c);
const deq = gf16_phi_dequantize(result);
try std.testing.expectApproxEqAbs(@as(f32, 11.0), deq, 2.0);
}

test "C-ABI: library version" {
const version = std.mem.span(goldenfloat_version());
try std.testing.expectEqualStrings("1.1.0", version);
Expand Down
47 changes: 47 additions & 0 deletions src/formats/golden_float16.zig
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,27 @@ pub const GF16 = packed struct(u16) {
pub fn phiDequantize(gf: GF16) f32 {
return gf.toF32() * PHI_SQ;
}

/// φ-optimized fused multiply-add: dequantize(a)*dequantize(b) + dequantize(c), then φ-quantize
pub fn phiFma(a: GF16, b: GF16, c: GF16) GF16 {
const fa = phiDequantize(a);
const fb = phiDequantize(b);
const fc = phiDequantize(c);
return phiQuantize(fa * fb + fc);
}

/// φ-optimized fused multiply-subtract: dequantize(a)*dequantize(b) - dequantize(c), then φ-quantize
pub fn phiFms(a: GF16, b: GF16, c: GF16) GF16 {
const fa = phiDequantize(a);
const fb = phiDequantize(b);
const fc = phiDequantize(c);
return phiQuantize(fa * fb - fc);
}

/// Standard fused multiply-add (no φ scaling): a*b + c in f32, rounded to GF16
pub fn fma(a: GF16, b: GF16, c: GF16) GF16 {
return fromF32(a.toF32() * b.toF32() + c.toF32());
}
};

// ═════════════════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -417,4 +438,30 @@ test "PHI_SQ + 1/PHI_SQ equals 3" {
try std.testing.expectApproxEqAbs(@as(f32, 3.0), computed, 1e-10);
}

test "GF16 phi-fused multiply-add" {
const a = GF16.phiQuantize(2.0);
const b = GF16.phiQuantize(3.0);
const c = GF16.phiQuantize(4.0);
const result = GF16.phiFma(a, b, c);
const deq = GF16.phiDequantize(result);
try std.testing.expectApproxEqAbs(@as(f32, 10.0), deq, 1.5);
}

test "GF16 phi-fused multiply-subtract" {
const a = GF16.phiQuantize(5.0);
const b = GF16.phiQuantize(3.0);
const c = GF16.phiQuantize(4.0);
const result = GF16.phiFms(a, b, c);
const deq = GF16.phiDequantize(result);
try std.testing.expectApproxEqAbs(@as(f32, 11.0), deq, 2.0);
}

test "GF16 standard fused multiply-add" {
const a = GF16.fromF32(2.0);
const b = GF16.fromF32(3.0);
const c = GF16.fromF32(4.0);
const result = GF16.fma(a, b, c);
try std.testing.expectApproxEqAbs(@as(f32, 10.0), result.toF32(), 0.5);
}

// φ² + 1/φ² = 3 | TRINITY
Loading