-
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathformat_comparison.zig
More file actions
111 lines (97 loc) · 4.69 KB
/
format_comparison.zig
File metadata and controls
111 lines (97 loc) · 4.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
const std = @import("std");
/// Bit layout of a binary floating-point-like number format.
const FormatSpec = struct {
    /// Human-readable label used in the report.
    name: []const u8,
    /// Total storage width in bits.
    total_bits: u8,
    /// Width of the biased exponent field.
    exp_bits: u8,
    /// Number of explicitly stored fraction bits; normal numbers carry an implicit leading 1.
    mant_bits: u8,
    /// Exponent bias: a stored exponent field `E` encodes the power of two `E - bias`.
    bias: i32,
};

/// Round `val` to the nearest value representable in the IEEE-754-style format `spec`
/// and return it widened back to `f64`.
///
/// Model used:
/// - Exponent field 0 encodes zero/subnormals and the all-ones field is reserved
///   (inf/NaN), so normal numbers have powers in `[1 - bias, 2^exp_bits - 2 - bias]`.
/// - Magnitudes below the smallest normal are rounded on the fixed subnormal grid
///   (gradual underflow, possibly all the way to zero).
/// - Finite magnitudes above the largest finite value saturate to that value instead
///   of overflowing to infinity.
/// - Infinities and NaN are returned unchanged; zero (of either sign) returns +0.
/// - Ties round away from zero (`@round`), not ties-to-even like IEEE hardware.
///
/// Requires `spec.exp_bits >= 1`; formats without an exponent field (e.g. ternary)
/// must be quantized by other means.
fn quantizeF64(val: f64, spec: FormatSpec) f64 {
    if (val == 0.0) return 0.0;
    if (!std.math.isFinite(val)) return val;

    const mant_bits: i32 = spec.mant_bits;
    // std.math.shl accepts the runtime u8 shift amount; a plain `<<` would need a u5.
    const max_field: i32 = std.math.shl(i32, 1, spec.exp_bits) - 1;
    const min_exp: i32 = 1 - spec.bias;
    const max_exp: i32 = max_field - 1 - spec.bias;

    const abs_val = @abs(val);
    // Floor, not truncation toward zero: |val| < 1 must use the binade below it.
    const binade: i32 = @intFromFloat(@floor(@log2(abs_val)));
    // Below the normal range the grid spacing stays fixed, which yields subnormals.
    const exponent = @max(binade, min_exp);
    const ulp = std.math.ldexp(@as(f64, 1.0), exponent - mant_bits);
    // Rounding up to 2^(exponent + 1) lies on this grid too, so no carry fix-up is needed.
    const rounded = @round(abs_val / ulp) * ulp;

    // Largest finite value: (2 - 2^-mant_bits) * 2^max_exp.
    const max_finite = std.math.ldexp(2.0 - std.math.ldexp(@as(f64, 1.0), -mant_bits), max_exp);
    const magnitude = @min(rounded, max_finite);
    return if (val < 0) -magnitude else magnitude;
}
/// Ternary quantizer: map `val` to -1, 0 or +1 using a symmetric dead zone of ±0.33.
fn quantizeTernary(val: f64) f64 {
    if (val > 0.33) return 1.0;
    if (val < -0.33) return -1.0;
    return 0.0;
}

/// Quantize `val` with `fmt`; the 2-bit format is treated as ternary rather than as a
/// floating-point layout (it has no exponent field for `quantizeF64` to use).
fn quantizeSample(val: f64, fmt: FormatSpec) f64 {
    return if (fmt.total_bits == 2) quantizeTernary(val) else quantizeF64(val, fmt);
}

/// Quantization error statistics of one format over a sample set.
const ErrorStats = struct {
    mse: f64,
    max_err: f64,
    rmse: f64,
};

/// Quantize every element of `samples` with `fmt` and summarize the absolute errors.
/// An empty `samples` yields all-zero statistics instead of dividing by zero.
fn measureError(samples: []const f64, fmt: FormatSpec) ErrorStats {
    if (samples.len == 0) return .{ .mse = 0, .max_err = 0, .rmse = 0 };
    var sum_sq: f64 = 0;
    var max_err: f64 = 0;
    for (samples) |val| {
        const err = @abs(val - quantizeSample(val, fmt));
        sum_sq += err * err;
        max_err = @max(max_err, err);
    }
    const mse = sum_sq / @as(f64, @floatFromInt(samples.len));
    return .{ .mse = mse, .max_err = max_err, .rmse = @sqrt(mse) };
}

/// Write the measured quantization statistics to `writer` as a JSON document.
/// `formats` and `stats` must be the same length; the multi-object `for` checks this
/// in safe builds.
fn writeJson(
    writer: anytype,
    seed: u64,
    n_samples: usize,
    formats: []const FormatSpec,
    stats: []const ErrorStats,
) !void {
    try writer.print(
        "{{\n  \"benchmark\": \"BENCH-001\",\n  \"seed\": {d},\n  \"n_samples\": {d},\n  \"quantization\": [\n",
        .{ seed, n_samples },
    );
    for (formats, stats, 0..) |fmt, st, i| {
        // JSON forbids a trailing comma after the last array element.
        const sep: []const u8 = if (i + 1 < formats.len) "," else "";
        try writer.print(
            "    {{ \"format\": \"{s}\", \"total_bits\": {d}, \"mse\": {d}, \"max_error\": {d}, \"rmse\": {d} }}{s}\n",
            .{ fmt.name, fmt.total_bits, st.mse, st.max_err, st.rmse, sep },
        );
    }
    try writer.writeAll("  ]\n}\n");
}

/// Run the BENCH-001 number-format comparison. Print quantization-error, dynamic-range
/// and LUT-cost tables to stdout, and save the measured quantization statistics to
/// `benchmarks/format_comparison.json` (creating the directory if needed).
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();
    const stdout = std.io.getStdOut().writer();

    const formats = [_]FormatSpec{
        .{ .name = "fp16", .total_bits = 16, .exp_bits = 5, .mant_bits = 10, .bias = 15 },
        .{ .name = "bfloat16", .total_bits = 16, .exp_bits = 8, .mant_bits = 7, .bias = 127 },
        .{ .name = "GF16", .total_bits = 16, .exp_bits = 6, .mant_bits = 9, .bias = 31 },
        .{ .name = "ternary", .total_bits = 2, .exp_bits = 0, .mant_bits = 0, .bias = 0 },
    };

    // `**` repeats an array at comptime; `*` on a string literal does not compile.
    const heavy_rule = "=" ** 72;

    try stdout.print("\n{s}\n", .{heavy_rule});
    try stdout.print(" BENCH-001: Number Format Comparison\n", .{});
    try stdout.print(" phi^2 + 1/phi^2 = 3 | TRINITY\n", .{});
    try stdout.print("{s}\n\n", .{heavy_rule});

    // 1. Quantization MSE
    const seed: u64 = 42;
    const n_samples: usize = 1000;
    try stdout.print(" 1. Quantization MSE ({d} samples from Normal(0,1))\n", .{n_samples});
    try stdout.print(" {s:<12} {s:>12} {s:>12} {s:>12}\n", .{ "Format", "MSE", "Max Error", "RMSE" });
    try stdout.print(" {s}\n", .{"-" ** 50});

    var prng = std.Random.DefaultPrng.init(seed);
    const random = prng.random();
    const samples = try allocator.alloc(f64, n_samples);
    defer allocator.free(samples);
    for (samples) |*s| s.* = random.floatNorm(f64);

    var stats: [formats.len]ErrorStats = undefined;
    for (formats, &stats) |fmt, *st| {
        st.* = measureError(samples, fmt);
        try stdout.print(" {s:<12} {d:>12.6} {d:>12.6} {d:>12.6}\n", .{ fmt.name, st.mse, st.max_err, st.rmse });
    }

    // 2. Dynamic Range: hard-coded reference values, not derived from `formats`.
    try stdout.print("\n 2. Dynamic Range\n", .{});
    try stdout.print(" {s:<12} {s:>16} {s:>16} {s:>10}\n", .{ "Format", "Min Positive", "Max Value", "Bits" });
    try stdout.print(" {s}\n", .{"-" ** 56});
    const dynamic_ranges = [_]struct { name: []const u8, min_pos: f64, max_val: f64, bits: u8 }{
        .{ .name = "fp16", .min_pos = 6.0e-5, .max_val = 65504.0, .bits = 16 },
        .{ .name = "bfloat16", .min_pos = 1.175e-38, .max_val = 3.389e38, .bits = 16 },
        .{ .name = "GF16", .min_pos = 4.657e-10, .max_val = 4.295e9, .bits = 16 },
        .{ .name = "ternary", .min_pos = 1.0, .max_val = 1.0, .bits = 2 },
    };
    for (dynamic_ranges) |dr| {
        // `{e}` selects scientific notation; precision goes after the width.
        try stdout.print(" {s:<12} {e:>16.3} {e:>16.3} {d:>10}\n", .{ dr.name, dr.min_pos, dr.max_val, dr.bits });
    }

    // 3. LUT Cost Estimate (XC7A100T)
    const fpga_luts: u32 = 63400;
    try stdout.print("\n 3. LUT Cost Estimate (XC7A100T, {d} LUTs)\n", .{fpga_luts});
    try stdout.print(" {s:<12} {s:>12} {s:>12} {s:>10}\n", .{ "Format", "ADD LUTs", "MUL LUTs", "% of FPGA" });
    try stdout.print(" {s}\n", .{"-" ** 48});
    const lut_costs = [_]struct { name: []const u8, add: u32, mul: u32 }{
        .{ .name = "fp16", .add = 200, .mul = 350 },
        .{ .name = "bfloat16", .add = 50, .mul = 300 },
        .{ .name = "GF16", .add = 150, .mul = 280 },
        .{ .name = "ternary", .add = 12, .mul = 8 },
    };
    for (lut_costs) |lc| {
        // Share of the device used by one adder plus one multiplier.
        const pct = 100.0 * @as(f64, @floatFromInt(lc.add + lc.mul)) / @as(f64, @floatFromInt(fpga_luts));
        try stdout.print(" {s:<12} {d:>12} {d:>12} {d:>10.2}\n", .{ lc.name, lc.add, lc.mul, pct });
    }

    // Persist the measured statistics so the final message is actually true.
    const json_dir = "benchmarks";
    const json_path = json_dir ++ "/format_comparison.json";
    try std.fs.cwd().makePath(json_dir);
    const file = try std.fs.cwd().createFile(json_path, .{});
    defer file.close();
    try writeJson(file.writer(), seed, n_samples, &formats, &stats);

    try stdout.print("\n{s}\n", .{heavy_rule});
    try stdout.print(" Results saved to {s}\n", .{json_path});
}