Skip to content

Commit 145c8aa

Browse files
committed
feat: φ-optimized FMA operations for GF16 (issue #4)
Add fused arithmetic to the GF16 struct:
- fma(a,b,c): a*b+c (single rounding)
- fms(a,b,c): a*b-c
- fnma(a,b,c): c-a*b (negated FMA)
- phiFma(a,b,c): φ-weighted FMA with golden-ratio scaling
- phiDot(a,b): φ-optimized dot product over slices

Export via C-ABI (gf16_fms, gf16_fnma, gf16_phi_fma) and declare them in the gf16.h header. Add 5 new unit tests. Closes #4
1 parent be611c7 commit 145c8aa

3 files changed

Lines changed: 136 additions & 5 deletions

File tree

src/c/gf16.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,39 @@ gf16_t gf16_max(gf16_t a, gf16_t b);
347347
*/
348348
gf16_t gf16_fma(gf16_t a, gf16_t b, gf16_t c);
349349

350+
/**
351+
* Fused multiply-subtract: a × b - c
352+
*
353+
* @param a First operand
354+
* @param b Second operand
355+
* @param c Third operand
356+
* @return a × b - c in GF16
357+
*/
358+
gf16_t gf16_fms(gf16_t a, gf16_t b, gf16_t c);
359+
360+
/**
361+
* Fused negated multiply-add: c - a × b
362+
*
363+
* @param a First operand
364+
* @param b Second operand
365+
* @param c Third operand
366+
* @return c - a × b in GF16
367+
*/
368+
gf16_t gf16_fnma(gf16_t a, gf16_t b, gf16_t c);
369+
370+
/**
371+
* φ-optimized FMA with golden-ratio scaling
372+
*
373+
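* Applies φ² (PHI_SQ) dequantization to each operand inside the fused operation,
* i.e. evaluates (φ²·a × φ²·b + φ²·c) / φ² and converts the result to GF16 once,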
374+
* for better distribution in ML weight space.
375+
*
376+
* @param a First operand
377+
* @param b Second operand
378+
* @param c Third operand
379+
* @return φ-weighted fused result in GF16
380+
*/
381+
gf16_t gf16_phi_fma(gf16_t a, gf16_t b, gf16_t c);
382+
350383
/*======================================================================
351384
* Constants
352385
*======================================================================*/

src/c_abi.zig

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,31 @@ export fn gf16_max(a: gf16_t, b: gf16_t) callconv(.c) gf16_t {
201201
}
202202

203203
/// C-ABI fused multiply-add: `a * b + c`.
///
/// Each raw `gf16_t` operand is converted with `rawToGf16`, the computation
/// is delegated to `golden.GF16.fma`, and the result is converted back to
/// its raw form with `gf16ToRaw`.
export fn gf16_fma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
    const fused = golden.GF16.fma(rawToGf16(a), rawToGf16(b), rawToGf16(c));
    return gf16ToRaw(fused);
}
209+
210+
/// C-ABI fused multiply-subtract: `a * b - c`.
///
/// Each raw `gf16_t` operand is converted with `rawToGf16`, the computation
/// is delegated to `golden.GF16.fms`, and the result is converted back to
/// its raw form with `gf16ToRaw`.
export fn gf16_fms(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
    const fused = golden.GF16.fms(rawToGf16(a), rawToGf16(b), rawToGf16(c));
    return gf16ToRaw(fused);
}
216+
217+
/// C-ABI fused negated multiply-add: `c - a * b`.
///
/// Each raw `gf16_t` operand is converted with `rawToGf16`, the computation
/// is delegated to `golden.GF16.fnma`, and the result is converted back to
/// its raw form with `gf16ToRaw`.
export fn gf16_fnma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
    const fused = golden.GF16.fnma(rawToGf16(a), rawToGf16(b), rawToGf16(c));
    return gf16ToRaw(fused);
}
223+
224+
/// C-ABI φ-weighted fused multiply-add.
///
/// Each raw `gf16_t` operand is converted with `rawToGf16`, the computation
/// is delegated to `golden.GF16.phiFma`, and the result is converted back to
/// its raw form with `gf16ToRaw`.
export fn gf16_phi_fma(a: gf16_t, b: gf16_t, c: gf16_t) callconv(.c) gf16_t {
    const fused = golden.GF16.phiFma(rawToGf16(a), rawToGf16(b), rawToGf16(c));
    return gf16ToRaw(fused);
}
210230

211231
// ═══════════════════════════════════════════════════════════════════

src/formats/golden_float16.zig

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,39 @@ pub const GF16 = packed struct(u16) {
190190
pub fn phiDequantize(gf: GF16) f32 {
191191
return gf.toF32() * PHI_SQ;
192192
}
193+
194+
/// Fused multiply-add: returns `a * b + c` as a GF16.
///
/// The operands are widened with `toF32`, the multiply and add are carried
/// out in f32, and the result is converted back to GF16 by a single
/// `fromF32` call.
pub fn fma(a: GF16, b: GF16, c: GF16) GF16 {
    const product = a.toF32() * b.toF32();
    return fromF32(product + c.toF32());
}
198+
199+
/// Fused multiply-subtract: returns `a * b - c` as a GF16.
///
/// The operands are widened with `toF32`, the multiply and subtract are
/// carried out in f32, and the result is converted back to GF16 by a single
/// `fromF32` call.
pub fn fms(a: GF16, b: GF16, c: GF16) GF16 {
    const product = a.toF32() * b.toF32();
    return fromF32(product - c.toF32());
}
203+
204+
/// Fused negated multiply-add: returns `c - a * b` as a GF16.
///
/// The operands are widened with `toF32`, the arithmetic is carried out in
/// f32, and the result is converted back to GF16 by a single `fromF32` call.
pub fn fnma(a: GF16, b: GF16, c: GF16) GF16 {
    const minuend = c.toF32();
    const product = a.toF32() * b.toF32();
    return fromF32(minuend - product);
}
208+
209+
/// φ-weighted fused multiply-add.
///
/// Evaluates `(a * b * PHI_SQ² + c * PHI_SQ) / PHI_SQ` in f32 and converts
/// the quotient to GF16 with a single `fromF32` call. The numerator matches
/// `phiDequantize(a) * phiDequantize(b) + phiDequantize(c)` term by term.
pub fn phiFma(a: GF16, b: GF16, c: GF16) GF16 {
    // Weight applied to the product: one PHI_SQ factor per multiplicand.
    const product_weight = PHI_SQ * PHI_SQ;
    const weighted_product = a.toF32() * b.toF32() * product_weight;
    const weighted_addend = c.toF32() * PHI_SQ;
    // Divide the weighted sum by PHI_SQ before the final conversion.
    return fromF32((weighted_product + weighted_addend) / PHI_SQ);
}
216+
217+
/// φ-scaled dot product of two GF16 slices.
///
/// Element products are accumulated in an f32 that starts at 0.0; the total
/// is multiplied by `PHI_INV_SQ` and converted to GF16 with one `fromF32`
/// call. `a` and `b` must have the same length (asserted).
pub fn phiDot(a: []const GF16, b: []const GF16) GF16 {
    std.debug.assert(a.len == b.len);
    var total: f32 = 0.0;
    for (a, 0..) |lhs, i| {
        total += lhs.toF32() * b[i].toF32();
    }
    return fromF32(total * PHI_INV_SQ);
}
193226
};
194227

195228
// ═════════════════════════════════════════════════════════════════════════════
@@ -370,6 +403,51 @@ test "GF16 arithmetic" {
370403
try std.testing.expectApproxEqAbs(@as(f32, 0.6), quot.toF32(), 0.05);
371404
}
372405

406+
test "GF16 FMA" {
    // 2 * 3 + 1 = 7
    const fused = GF16.fma(GF16.fromF32(2.0), GF16.fromF32(3.0), GF16.fromF32(1.0));
    const expected: f32 = 7.0;
    try std.testing.expectApproxEqAbs(expected, fused.toF32(), 0.1);
}
413+
414+
test "GF16 FMS" {
    // 2 * 3 - 1 = 5
    const fused = GF16.fms(GF16.fromF32(2.0), GF16.fromF32(3.0), GF16.fromF32(1.0));
    const expected: f32 = 5.0;
    try std.testing.expectApproxEqAbs(expected, fused.toF32(), 0.1);
}
421+
422+
test "GF16 FNMA" {
    // 10 - 2 * 3 = 4
    const fused = GF16.fnma(GF16.fromF32(2.0), GF16.fromF32(3.0), GF16.fromF32(10.0));
    const expected: f32 = 4.0;
    try std.testing.expectApproxEqAbs(expected, fused.toF32(), 0.1);
}
429+
430+
test "GF16 phiFMA" {
    // Operands 1, 1, 0: the dequantized result only has to land in a loose (1, 10) window.
    const fused = GF16.phiFma(GF16.fromF32(1.0), GF16.fromF32(1.0), GF16.fromF32(0.0));
    const dequantized = fused.phiDequantize();
    try std.testing.expect(dequantized > 1.0 and dequantized < 10.0);
}
438+
439+
test "GF16 phiDot product" {
    // Plain dot product of (1, 2, 3) and (1, 1, 1) is 6; after phiDot and
    // phiDequantize the value must fall inside the (4, 8) window.
    const lhs = [_]GF16{ GF16.fromF32(1.0), GF16.fromF32(2.0), GF16.fromF32(3.0) };
    const rhs = [_]GF16{ GF16.fromF32(1.0), GF16.fromF32(1.0), GF16.fromF32(1.0) };
    const dequantized = GF16.phiDot(&lhs, &rhs).phiDequantize();
    try std.testing.expect(dequantized > 4.0 and dequantized < 8.0);
}
450+
373451
test "GF16 phi quantization roundtrip" {
374452
const original = 2.71828;
375453
const quantized = GF16.phiQuantize(original);

0 commit comments

Comments
 (0)