|
| 1 | +const std = @import("std"); |
| 2 | + |
| 3 | +pub const DatasetType = enum { |
| 4 | + code_completion, |
| 5 | + medical_notes, |
| 6 | + scientific_papers, |
| 7 | + synthetic, |
| 8 | +}; |
| 9 | + |
| 10 | +pub const DatasetSpec = struct { |
| 11 | + name: []const u8, |
| 12 | + dataset_type: DatasetType, |
| 13 | + num_samples: usize, |
| 14 | + avg_tokens_per_sample: usize, |
| 15 | + vocab_coverage: f32, |
| 16 | +}; |
| 17 | + |
| 18 | +pub const BenchmarkMetrics = struct { |
| 19 | + ppl: f32, |
| 20 | + accuracy: f32, |
| 21 | + training_epochs_per_hour: f32, |
| 22 | + inference_tok_per_sec: f32, |
| 23 | + model_size_mb: f32, |
| 24 | +}; |
| 25 | + |
| 26 | +pub const FormatComparison = struct { |
| 27 | + format_name: []const u8, |
| 28 | + bits_per_weight: f32, |
| 29 | + metrics: BenchmarkMetrics, |
| 30 | + gap_vs_fp32_ppl: f32, |
| 31 | +}; |
| 32 | + |
| 33 | +pub const DomainBenchmark = struct { |
| 34 | + allocator: std.mem.Allocator, |
| 35 | + dataset: DatasetSpec, |
| 36 | + comparisons: std.ArrayList(FormatComparison), |
| 37 | + fp32_baseline: ?BenchmarkMetrics, |
| 38 | + |
| 39 | + pub fn init(allocator: std.mem.Allocator, dataset: DatasetSpec) DomainBenchmark { |
| 40 | + return .{ |
| 41 | + .allocator = allocator, |
| 42 | + .dataset = dataset, |
| 43 | + .comparisons = std.ArrayList(FormatComparison).init(allocator), |
| 44 | + .fp32_baseline = null, |
| 45 | + }; |
| 46 | + } |
| 47 | + |
| 48 | + pub fn deinit(self: *DomainBenchmark) void { |
| 49 | + self.comparisons.deinit(); |
| 50 | + } |
| 51 | + |
| 52 | + pub fn setBaseline(self: *DomainBenchmark, metrics: BenchmarkMetrics) void { |
| 53 | + self.fp32_baseline = metrics; |
| 54 | + } |
| 55 | + |
| 56 | + pub fn addComparison(self: *DomainBenchmark, comp: FormatComparison) !void { |
| 57 | + if (self.fp32_baseline) |baseline| { |
| 58 | + var mut_comp = comp; |
| 59 | + mut_comp.gap_vs_fp32_ppl = (comp.metrics.ppl - baseline.ppl) / baseline.ppl * 100.0; |
| 60 | + try self.comparisons.append(mut_comp); |
| 61 | + } else { |
| 62 | + try self.comparisons.append(comp); |
| 63 | + } |
| 64 | + } |
| 65 | + |
| 66 | + pub fn passesThreshold(self: *const DomainBenchmark, max_ppl_gap_pct: f32) bool { |
| 67 | + for (self.comparisons.items) |comp| { |
| 68 | + if (comp.gap_vs_fp32_ppl > max_ppl_gap_pct) return false; |
| 69 | + } |
| 70 | + return true; |
| 71 | + } |
| 72 | + |
| 73 | + pub fn printReport(self: *const DomainBenchmark, writer: anytype) !void { |
| 74 | + try writer.print("\n Domain Benchmark: {s} ({s})\n", .{ self.dataset.name, @tagName(self.dataset.dataset_type) }); |
| 75 | + try writer.print(" {s}\n", .{"-" * 72}); |
| 76 | + try writer.print(" Samples: {d} | Avg tokens: {d} | Vocab coverage: {d:.0}%\n\n", .{ |
| 77 | + self.dataset.num_samples, |
| 78 | + self.dataset.avg_tokens_per_sample, |
| 79 | + self.dataset.vocab_coverage * 100, |
| 80 | + }); |
| 81 | + |
| 82 | + if (self.fp32_baseline) |bl| { |
| 83 | + try writer.print(" FP32 Baseline: PPL={d:.2}, tok/s={d:.0}, size={d:.1}MB\n\n", .{ |
| 84 | + bl.ppl, |
| 85 | + bl.inference_tok_per_sec, |
| 86 | + bl.model_size_mb, |
| 87 | + }); |
| 88 | + } |
| 89 | + |
| 90 | + try writer.print(" {s:<12} {s:>6} {s:>8} {s:>8} {s:>12} {s:>8}\n", .{ |
| 91 | + "Format", "Bits", "PPL", "Acc%", "Tok/s", "PPL Gap" }); |
| 92 | + try writer.print(" {s}\n", .{"-" * 72}); |
| 93 | + |
| 94 | + for (self.comparisons.items) |comp| { |
| 95 | + try writer.print(" {s:<12} {d:>5.1} {d:>8.2} {d:>7.1}% {d:>11.0} {d:>7.1}%\n", .{ |
| 96 | + comp.format_name, |
| 97 | + comp.bits_per_weight, |
| 98 | + comp.metrics.ppl, |
| 99 | + comp.metrics.accuracy * 100, |
| 100 | + comp.metrics.inference_tok_per_sec, |
| 101 | + comp.gap_vs_fp32_ppl, |
| 102 | + }); |
| 103 | + } |
| 104 | + try writer.print(" {s}\n\n", .{"-" * 72}); |
| 105 | + } |
| 106 | +}; |
| 107 | + |
| 108 | +pub const BenchmarkSuite = struct { |
| 109 | + allocator: std.mem.Allocator, |
| 110 | + benchmarks: std.ArrayList(DomainBenchmark), |
| 111 | + |
| 112 | + pub fn init(allocator: std.mem.Allocator) BenchmarkSuite { |
| 113 | + return .{ |
| 114 | + .allocator = allocator, |
| 115 | + .benchmarks = std.ArrayList(DomainBenchmark).init(allocator), |
| 116 | + }; |
| 117 | + } |
| 118 | + |
| 119 | + pub fn deinit(self: *BenchmarkSuite) void { |
| 120 | + for (self.benchmarks.items) |*b| b.deinit(); |
| 121 | + self.benchmarks.deinit(); |
| 122 | + } |
| 123 | + |
| 124 | + pub fn addDataset(self: *BenchmarkSuite, spec: DatasetSpec) !*DomainBenchmark { |
| 125 | + try self.benchmarks.append(DomainBenchmark.init(self.allocator, spec)); |
| 126 | + return &self.benchmarks.items[self.benchmarks.items.len - 1]; |
| 127 | + } |
| 128 | + |
| 129 | + pub fn allPassThreshold(self: *const BenchmarkSuite, max_ppl_gap_pct: f32) bool { |
| 130 | + for (self.benchmarks.items) |b| { |
| 131 | + if (!b.passesThreshold(max_ppl_gap_pct)) return false; |
| 132 | + } |
| 133 | + return true; |
| 134 | + } |
| 135 | + |
| 136 | + pub fn overallSummary(self: *const BenchmarkSuite) struct { avg_ppl_gap: f32, datasets: usize, passing: usize } { |
| 137 | + var total_gap: f32 = 0; |
| 138 | + var count: usize = 0; |
| 139 | + var passing: usize = 0; |
| 140 | + for (self.benchmarks.items) |b| { |
| 141 | + for (b.comparisons.items) |c| { |
| 142 | + total_gap += c.gap_vs_fp32_ppl; |
| 143 | + count += 1; |
| 144 | + if (c.gap_vs_fp32_ppl <= 10.0) passing += 1; |
| 145 | + } |
| 146 | + } |
| 147 | + return .{ |
| 148 | + .avg_ppl_gap = if (count > 0) total_gap / @as(f32, @floatFromInt(count)) else 0, |
| 149 | + .datasets = self.benchmarks.items.len, |
| 150 | + .passing = passing, |
| 151 | + }; |
| 152 | + } |
| 153 | +}; |
| 154 | + |
| 155 | +test "domain benchmark with baseline" { |
| 156 | + const allocator = std.testing.allocator; |
| 157 | + var bench = DomainBenchmark.init(allocator, .{ |
| 158 | + .name = "ArXiv Abstracts", |
| 159 | + .dataset_type = .scientific_papers, |
| 160 | + .num_samples = 50000, |
| 161 | + .avg_tokens_per_sample = 150, |
| 162 | + .vocab_coverage = 0.85, |
| 163 | + }); |
| 164 | + defer bench.deinit(); |
| 165 | + |
| 166 | + bench.setBaseline(.{ |
| 167 | + .ppl = 45.0, |
| 168 | + .accuracy = 0.0, |
| 169 | + .training_epochs_per_hour = 12.0, |
| 170 | + .inference_tok_per_sec = 15000, |
| 171 | + .model_size_mb = 10.8, |
| 172 | + }); |
| 173 | + |
| 174 | + try bench.addComparison(.{ |
| 175 | + .format_name = "Ternary", |
| 176 | + .bits_per_weight = 2, |
| 177 | + .metrics = .{ |
| 178 | + .ppl = 49.5, |
| 179 | + .accuracy = 0.0, |
| 180 | + .training_epochs_per_hour = 18.0, |
| 181 | + .inference_tok_per_sec = 45000, |
| 182 | + .model_size_mb = 2.7, |
| 183 | + }, |
| 184 | + .gap_vs_fp32_ppl = 0, |
| 185 | + }); |
| 186 | + |
| 187 | + try std.testing.expect(bench.comparisons.items.len == 1); |
| 188 | + try std.testing.expect(bench.comparisons.items[0].gap_vs_fp32_ppl > 0); |
| 189 | +} |
| 190 | + |
| 191 | +test "benchmark suite multi-dataset" { |
| 192 | + const allocator = std.testing.allocator; |
| 193 | + var suite = BenchmarkSuite.init(allocator); |
| 194 | + defer suite.deinit(); |
| 195 | + |
| 196 | + const b1 = try suite.addDataset(.{ |
| 197 | + .name = "GitHub Code", |
| 198 | + .dataset_type = .code_completion, |
| 199 | + .num_samples = 100000, |
| 200 | + .avg_tokens_per_sample = 200, |
| 201 | + .vocab_coverage = 0.92, |
| 202 | + }); |
| 203 | + |
| 204 | + const b2 = try suite.addDataset(.{ |
| 205 | + .name = "ArXiv", |
| 206 | + .dataset_type = .scientific_papers, |
| 207 | + .num_samples = 50000, |
| 208 | + .avg_tokens_per_sample = 150, |
| 209 | + .vocab_coverage = 0.85, |
| 210 | + }); |
| 211 | + |
| 212 | + b1.setBaseline(.{ .ppl = 30, .accuracy = 0.0, .training_epochs_per_hour = 12, .inference_tok_per_sec = 15000, .model_size_mb = 10.8 }); |
| 213 | + b2.setBaseline(.{ .ppl = 45, .accuracy = 0.0, .training_epochs_per_hour = 10, .inference_tok_per_sec = 14000, .model_size_mb = 10.8 }); |
| 214 | + |
| 215 | + try b1.addComparison(.{ .format_name = "GF16", .bits_per_weight = 16, .metrics = .{ .ppl = 30.5, .accuracy = 0.0, .training_epochs_per_hour = 14, .inference_tok_per_sec = 20000, .model_size_mb = 5.4 }, .gap_vs_fp32_ppl = 0 }); |
| 216 | + try b2.addComparison(.{ .format_name = "GF16", .bits_per_weight = 16, .metrics = .{ .ppl = 46.0, .accuracy = 0.0, .training_epochs_per_hour = 12, .inference_tok_per_sec = 19000, .model_size_mb = 5.4 }, .gap_vs_fp32_ppl = 0 }); |
| 217 | + |
| 218 | + const summary = suite.overallSummary(); |
| 219 | + try std.testing.expectEqual(@as(usize, 2), summary.datasets); |
| 220 | + try std.testing.expect(summary.avg_ppl_gap < 10.0); |
| 221 | +} |
| 222 | + |
| 223 | +test "pass threshold check" { |
| 224 | + const allocator = std.testing.allocator; |
| 225 | + var bench = DomainBenchmark.init(allocator, .{ |
| 226 | + .name = "test", |
| 227 | + .dataset_type = .synthetic, |
| 228 | + .num_samples = 100, |
| 229 | + .avg_tokens_per_sample = 50, |
| 230 | + .vocab_coverage = 0.5, |
| 231 | + }); |
| 232 | + defer bench.deinit(); |
| 233 | + |
| 234 | + bench.setBaseline(.{ .ppl = 100, .accuracy = 0.0, .training_epochs_per_hour = 1, .inference_tok_per_sec = 1, .model_size_mb = 1 }); |
| 235 | + try bench.addComparison(.{ .format_name = "Ternary", .bits_per_weight = 2, .metrics = .{ .ppl = 108, .accuracy = 0.0, .training_epochs_per_hour = 2, .inference_tok_per_sec = 3, .model_size_mb = 0.25 }, .gap_vs_fp32_ppl = 0 }); |
| 236 | + |
| 237 | + try std.testing.expect(bench.passesThreshold(10.0)); |
| 238 | + try std.testing.expect(!bench.passesThreshold(5.0)); |
| 239 | +} |
0 commit comments