diff --git a/src/c/gf16.h b/src/c/gf16.h index df914f3..a9d3808 100644 --- a/src/c/gf16.h +++ b/src/c/gf16.h @@ -347,6 +347,32 @@ gf16_t gf16_max(gf16_t a, gf16_t b); */ gf16_t gf16_fma(gf16_t a, gf16_t b, gf16_t c); +/** + * φ-optimized fused multiply-add + * + * Dequantizes inputs from φ-space, computes a × b + c in f32, + * then φ-quantizes the result back. + * + * @param a First operand (φ-quantized) + * @param b Second operand (φ-quantized) + * @param c Third operand (φ-quantized) + * @return φ-quantized result of a × b + c + */ +gf16_t gf16_phi_fma(gf16_t a, gf16_t b, gf16_t c); + +/** + * φ-optimized fused multiply-subtract + * + * Dequantizes inputs from φ-space, computes a × b - c in f32, + * then φ-quantizes the result back. + * + * @param a First operand (φ-quantized) + * @param b Second operand (φ-quantized) + * @param c Third operand (φ-quantized) + * @return φ-quantized result of a × b - c + */ +gf16_t gf16_phi_fms(gf16_t a, gf16_t b, gf16_t c); + /*====================================================================== * Constants *======================================================================*/ diff --git a/src/c_abi.zig b/src/c_abi.zig index e5da396..e619085 100644 --- a/src/c_abi.zig +++ b/src/c_abi.zig @@ -201,13 +201,20 @@ export fn gf16_max(a: gf16_t, b: gf16_t) callconv(.c) gf16_t { } export fn gf16_fma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t { - // Compute a * b + c in f32, then round to GF16 const fa = rawToGf16(a).toF32(); const fb = rawToGf16(b).toF32(); const fc = rawToGf16(c).toF32(); return gf16ToRaw(golden.GF16.fromF32(fa * fb + fc)); } +export fn gf16_phi_fma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t { + return gf16ToRaw(golden.GF16.phiFma(rawToGf16(a), rawToGf16(b), rawToGf16(c))); +} + +export fn gf16_phi_fms(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t { + return gf16ToRaw(golden.GF16.phiFms(rawToGf16(a), rawToGf16(b), rawToGf16(c))); +} + // ═══════════════════════════════════════════════════════════════════ // Library Info // ═════════════════════════════════════════════════════════════════════ @@ -309,6 +316,24 @@ test "C-ABI: gf16_fma" { try std.testing.expectApproxEqAbs(@as(f32, 10.0), val, 0.05); } +test "C-ABI: gf16_phi_fma" { + const a = gf16_phi_quantize(2.0); + const b = gf16_phi_quantize(3.0); + const c = gf16_phi_quantize(4.0); + const result = gf16_phi_fma(a, b, c); + const deq = gf16_phi_dequantize(result); + try std.testing.expectApproxEqAbs(@as(f32, 10.0), deq, 1.5); +} + +test "C-ABI: gf16_phi_fms" { + const a = gf16_phi_quantize(5.0); + const b = gf16_phi_quantize(3.0); + const c = gf16_phi_quantize(4.0); + const result = gf16_phi_fms(a, b, c); + const deq = gf16_phi_dequantize(result); + try std.testing.expectApproxEqAbs(@as(f32, 11.0), deq, 2.0); +} + test "C-ABI: library version" { const version = std.mem.span(goldenfloat_version()); try std.testing.expectEqualStrings("1.1.0", version); diff --git a/src/formats/golden_float16.zig b/src/formats/golden_float16.zig index 71db419..d04d1fe 100644 --- a/src/formats/golden_float16.zig +++ b/src/formats/golden_float16.zig @@ -190,6 +190,27 @@ pub const GF16 = packed struct(u16) { pub fn phiDequantize(gf: GF16) f32 { return gf.toF32() * PHI_SQ; } + + /// φ-optimized fused multiply-add: dequantize(a)*dequantize(b) + dequantize(c), then φ-quantize + pub fn phiFma(a: GF16, b: GF16, c: GF16) GF16 { + const fa = phiDequantize(a); + const fb = phiDequantize(b); + const fc = phiDequantize(c); + return phiQuantize(fa * fb + fc); + } + + /// φ-optimized fused multiply-subtract: dequantize(a)*dequantize(b) - dequantize(c), then φ-quantize + pub fn phiFms(a: GF16, b: GF16, c: GF16) GF16 { + const fa = phiDequantize(a); + const fb = phiDequantize(b); + const fc = phiDequantize(c); + return phiQuantize(fa * fb - fc); + } + + /// Standard fused multiply-add (no φ scaling): a*b + c in f32, rounded to GF16 + pub fn fma(a: GF16, b: GF16, c: GF16) GF16 { + return fromF32(a.toF32() * b.toF32() + c.toF32()); + } }; // ═════════════════════════════════════════════════════════════════════════════ @@ -417,4 +438,30 @@ test "PHI_SQ + 1/PHI_SQ equals 3" { try std.testing.expectApproxEqAbs(@as(f32, 3.0), computed, 1e-10); } +test "GF16 phi-fused multiply-add" { + const a = GF16.phiQuantize(2.0); + const b = GF16.phiQuantize(3.0); + const c = GF16.phiQuantize(4.0); + const result = GF16.phiFma(a, b, c); + const deq = GF16.phiDequantize(result); + try std.testing.expectApproxEqAbs(@as(f32, 10.0), deq, 1.5); +} + +test "GF16 phi-fused multiply-subtract" { + const a = GF16.phiQuantize(5.0); + const b = GF16.phiQuantize(3.0); + const c = GF16.phiQuantize(4.0); + const result = GF16.phiFms(a, b, c); + const deq = GF16.phiDequantize(result); + try std.testing.expectApproxEqAbs(@as(f32, 11.0), deq, 2.0); +} + +test "GF16 standard fused multiply-add" { + const a = GF16.fromF32(2.0); + const b = GF16.fromF32(3.0); + const c = GF16.fromF32(4.0); + const result = GF16.fma(a, b, c); + try std.testing.expectApproxEqAbs(@as(f32, 10.0), result.toF32(), 0.5); +} + // φ² + 1/φ² = 3 | TRINITY