diff --git a/src/c/gf16.h b/src/c/gf16.h
index df914f3..43fecf8 100644
--- a/src/c/gf16.h
+++ b/src/c/gf16.h
@@ -347,6 +347,39 @@ gf16_t gf16_max(gf16_t a, gf16_t b);
  */
 gf16_t gf16_fma(gf16_t a, gf16_t b, gf16_t c);
 
+/**
+ * Fused multiply-subtract: a × b - c
+ *
+ * @param a First operand
+ * @param b Second operand
+ * @param c Third operand
+ * @return a × b - c in GF16
+ */
+gf16_t gf16_fms(gf16_t a, gf16_t b, gf16_t c);
+
+/**
+ * Fused negated multiply-add: c - a × b
+ *
+ * @param a First operand
+ * @param b Second operand
+ * @param c Third operand
+ * @return c - a × b in GF16
+ */
+gf16_t gf16_fnma(gf16_t a, gf16_t b, gf16_t c);
+
+/**
+ * φ-optimized FMA with golden-ratio scaling
+ *
+ * Evaluates ((s·a) × (s·b) + s·c) / s in f32 and converts to GF16 once,
+ * where s is the φ-dequantization scale (aimed at ML weight distributions).
+ *
+ * @param a First operand
+ * @param b Second operand
+ * @param c Third operand
+ * @return ((s·a) × (s·b) + s·c) / s in GF16
+ */
+gf16_t gf16_phi_fma(gf16_t a, gf16_t b, gf16_t c);
+
 /*======================================================================
  * Constants
  *======================================================================*/
diff --git a/src/c_abi.zig b/src/c_abi.zig
index e5da396..65742d4 100644
--- a/src/c_abi.zig
+++ b/src/c_abi.zig
@@ -201,11 +201,31 @@ export fn gf16_max(a: gf16_t, b: gf16_t) callconv(.c) gf16_t {
 }
 
 export fn gf16_fma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
-    // Compute a * b + c in f32, then round to GF16
-    const fa = rawToGf16(a).toF32();
-    const fb = rawToGf16(b).toF32();
-    const fc = rawToGf16(c).toF32();
-    return gf16ToRaw(golden.GF16.fromF32(fa * fb + fc));
+    const gf_a = rawToGf16(a);
+    const gf_b = rawToGf16(b);
+    const gf_c = rawToGf16(c);
+    return gf16ToRaw(golden.GF16.fma(gf_a, gf_b, gf_c));
+}
+
+export fn gf16_fms(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
+    const gf_a = rawToGf16(a);
+    const gf_b = rawToGf16(b);
+    const gf_c = rawToGf16(c);
+    return gf16ToRaw(golden.GF16.fms(gf_a, gf_b, gf_c));
+}
+
+export fn gf16_fnma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
+    const gf_a = rawToGf16(a);
+    const gf_b = rawToGf16(b);
+    const gf_c = rawToGf16(c);
+    return gf16ToRaw(golden.GF16.fnma(gf_a, gf_b, gf_c));
+}
+
+export fn gf16_phi_fma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
+    const gf_a = rawToGf16(a);
+    const gf_b = rawToGf16(b);
+    const gf_c = rawToGf16(c);
+    return gf16ToRaw(golden.GF16.phiFma(gf_a, gf_b, gf_c));
 }
 
 // ═══════════════════════════════════════════════════════════════════
diff --git a/src/formats/golden_float16.zig b/src/formats/golden_float16.zig
index 71db419..64196d3 100644
--- a/src/formats/golden_float16.zig
+++ b/src/formats/golden_float16.zig
@@ -190,6 +190,39 @@ pub const GF16 = packed struct(u16) {
     pub fn phiDequantize(gf: GF16) f32 {
         return gf.toF32() * PHI_SQ;
     }
+
+    /// Fused multiply-add: a * b + c, evaluated in f32 and converted to GF16 once
+    pub fn fma(a: GF16, b: GF16, c: GF16) GF16 {
+        return fromF32(a.toF32() * b.toF32() + c.toF32());
+    }
+
+    /// Fused multiply-subtract: a * b - c, evaluated in f32 and converted to GF16 once
+    pub fn fms(a: GF16, b: GF16, c: GF16) GF16 {
+        return fromF32(a.toF32() * b.toF32() - c.toF32());
+    }
+
+    /// Fused negated multiply-add: c - a * b, evaluated in f32 and converted to GF16 once
+    pub fn fnma(a: GF16, b: GF16, c: GF16) GF16 {
+        return fromF32(c.toF32() - a.toF32() * b.toF32());
+    }
+
+    /// φ-optimized FMA: (phiDequantize(a) * phiDequantize(b) + phiDequantize(c)) / PHI_SQ,
+    /// evaluated in f32 and converted to GF16 once (algebraically a * b * PHI_SQ + c)
+    pub fn phiFma(a: GF16, b: GF16, c: GF16) GF16 {
+        const scale = PHI_SQ * PHI_SQ; // one PHI_SQ factor each for a and b
+        const sum = a.toF32() * b.toF32() * scale + c.toF32() * PHI_SQ;
+        return fromF32(sum / PHI_SQ); // drop one PHI_SQ factor before the single conversion
+    }
+
+    /// φ-scaled dot product: sums a[i] * b[i] in an f32 accumulator, then multiplies by PHI_INV_SQ
+    pub fn phiDot(a: []const GF16, b: []const GF16) GF16 {
+        std.debug.assert(a.len == b.len);
+        var acc: f32 = 0.0;
+        for (a, b) |ai, bi| {
+            acc += ai.toF32() * bi.toF32();
+        }
+        return fromF32(acc * PHI_INV_SQ);
+    }
 };
 
 // ═════════════════════════════════════════════════════════════════════════════
@@ -370,6 +403,51 @@ test "GF16 arithmetic" {
     try std.testing.expectApproxEqAbs(@as(f32, 0.6), quot.toF32(), 0.05);
 }
 
+test "GF16 FMA" {
+    const a = GF16.fromF32(2.0);
+    const b = GF16.fromF32(3.0);
+    const c = GF16.fromF32(1.0);
+    const result = GF16.fma(a, b, c);
+    try std.testing.expectApproxEqAbs(@as(f32, 7.0), result.toF32(), 0.1);
+}
+
+test "GF16 FMS" {
+    const a = GF16.fromF32(2.0);
+    const b = GF16.fromF32(3.0);
+    const c = GF16.fromF32(1.0);
+    const result = GF16.fms(a, b, c);
+    try std.testing.expectApproxEqAbs(@as(f32, 5.0), result.toF32(), 0.1);
+}
+
+test "GF16 FNMA" {
+    const a = GF16.fromF32(2.0);
+    const b = GF16.fromF32(3.0);
+    const c = GF16.fromF32(10.0);
+    const result = GF16.fnma(a, b, c);
+    try std.testing.expectApproxEqAbs(@as(f32, 4.0), result.toF32(), 0.1);
+}
+
+test "GF16 phiFMA" {
+    const a = GF16.fromF32(1.0);
+    const b = GF16.fromF32(1.0);
+    const c = GF16.fromF32(0.0);
+    const result = GF16.phiFma(a, b, c);
+    const deq = result.phiDequantize();
+    try std.testing.expect(deq > 1.0 and deq < 10.0); // deq is about PHI_SQ * PHI_SQ for a = b = 1, c = 0
+}
+
+test "GF16 phiDot product" {
+    const a_vals = [_]f32{ 1.0, 2.0, 3.0 };
+    const b_vals = [_]f32{ 1.0, 1.0, 1.0 };
+    var a_gf: [3]GF16 = undefined;
+    var b_gf: [3]GF16 = undefined;
+    for (&a_gf, a_vals) |*g, v| g.* = GF16.fromF32(v);
+    for (&b_gf, b_vals) |*g, v| g.* = GF16.fromF32(v);
+    const result = GF16.phiDot(&a_gf, &b_gf);
+    const back = result.phiDequantize();
+    try std.testing.expect(back > 4.0 and back < 8.0); // back is about (1 + 2 + 3) * PHI_INV_SQ * PHI_SQ
+}
+
 test "GF16 phi quantization roundtrip" {
     const original = 2.71828;
     const quantized = GF16.phiQuantize(original);