|
| 1 | +const std = @import("std"); |
| 2 | + |
| 3 | +pub const EntropyGate = struct { |
| 4 | + id: []const u8, |
| 5 | + description: []const u8, |
| 6 | + threshold: f32, |
| 7 | + passed: bool, |
| 8 | + value: f32, |
| 9 | +}; |
| 10 | + |
| 11 | +pub const SweepResult = struct { |
| 12 | + gate_id: []const u8, |
| 13 | + entropy: f32, |
| 14 | + rank: usize, |
| 15 | + passed: bool, |
| 16 | +}; |
| 17 | + |
| 18 | +pub const SweepConfig = struct { |
| 19 | + num_samples: usize = 10000, |
| 20 | + vocab_size: usize = 729, |
| 21 | + grid_size: usize = 9, |
| 22 | + num_states: usize = 9, |
| 23 | + top_k: usize = 3, |
| 24 | +}; |
| 25 | + |
| 26 | +pub fn computeShannonEntropy(distribution: []const f32) f32 { |
| 27 | + var entropy: f32 = 0; |
| 28 | + for (distribution) |p| { |
| 29 | + if (p > 1e-10) { |
| 30 | + entropy -= p * std.math.log2(p); |
| 31 | + } |
| 32 | + } |
| 33 | + return entropy; |
| 34 | +} |
| 35 | + |
| 36 | +pub fn computeEntropyFromCounts(counts: []const u32) f32 { |
| 37 | + var total: f32 = 0; |
| 38 | + for (counts) |c| total += @as(f32, @floatFromInt(c)); |
| 39 | + if (total < 1e-10) return 0; |
| 40 | + |
| 41 | + var entropy: f32 = 0; |
| 42 | + for (counts) |c| { |
| 43 | + const p = @as(f32, @floatFromInt(c)) / total; |
| 44 | + if (p > 1e-10) { |
| 45 | + entropy -= p * std.math.log2(p); |
| 46 | + } |
| 47 | + } |
| 48 | + return entropy; |
| 49 | +} |
| 50 | + |
| 51 | +pub const EntropySweeper = struct { |
| 52 | + allocator: std.mem.Allocator, |
| 53 | + config: SweepConfig, |
| 54 | + rng: std.Random.DefaultPrng, |
| 55 | + |
| 56 | + pub fn init(allocator: std.mem.Allocator, config: SweepConfig) EntropySweeper { |
| 57 | + return .{ |
| 58 | + .allocator = allocator, |
| 59 | + .config = config, |
| 60 | + .rng = std.Random.DefaultPrng.init(42), |
| 61 | + }; |
| 62 | + } |
| 63 | + |
| 64 | + pub fn sweepGate(self: *EntropySweeper, gate_id: []const u8, threshold: f32) !SweepResult { |
| 65 | + const counts = try self.allocator.alloc(u32, self.config.vocab_size); |
| 66 | + defer self.allocator.free(counts); |
| 67 | + @memset(counts, 0); |
| 68 | + |
| 69 | + const random = self.rng.random(); |
| 70 | + for (0..self.config.num_samples) |_| { |
| 71 | + const idx = random.intRangeLessThan(usize, 0, self.config.vocab_size); |
| 72 | + counts[idx] += 1; |
| 73 | + } |
| 74 | + |
| 75 | + const entropy = computeEntropyFromCounts(counts); |
| 76 | + const passed = entropy >= threshold; |
| 77 | + |
| 78 | + return .{ |
| 79 | + .gate_id = gate_id, |
| 80 | + .entropy = entropy, |
| 81 | + .rank = 0, |
| 82 | + .passed = passed, |
| 83 | + }; |
| 84 | + } |
| 85 | + |
| 86 | + pub fn sweepAll(self: *EntropySweeper, gates: []const EntropyGate) ![]SweepResult { |
| 87 | + var results = try self.allocator.alloc(SweepResult, gates.len); |
| 88 | + |
| 89 | + for (gates, results) |gate, *result| { |
| 90 | + result.* = try self.sweepGate(gate.id, gate.threshold); |
| 91 | + } |
| 92 | + |
| 93 | + for (results, 0..) |r, i| { |
| 94 | + var rank: usize = 1; |
| 95 | + for (results) |other| { |
| 96 | + if (other.entropy > r.entropy) rank += 1; |
| 97 | + } |
| 98 | + results[i].rank = rank; |
| 99 | + } |
| 100 | + |
| 101 | + return results; |
| 102 | + } |
| 103 | + |
| 104 | + pub fn topK(self: *EntropySweeper, results: []SweepResult, k: usize) []SweepResult { |
| 105 | + var sorted = self.allocator.dupe(SweepResult, results) catch return results[0..@min(k, results.len)]; |
| 106 | + std.mem.sort(SweepResult, sorted, {}, struct { |
| 107 | + pub fn lessThan(_: void, a: SweepResult, b: SweepResult) bool { |
| 108 | + return a.entropy > b.entropy; |
| 109 | + } |
| 110 | + }.lessThan); |
| 111 | + return sorted[0..@min(k, sorted.len)]; |
| 112 | + } |
| 113 | +}; |
| 114 | + |
| 115 | +pub fn printSweepReport(results: []const SweepResult, top_results: []const SweepResult, writer: anytype) !void { |
| 116 | + try writer.print("\n G1-G8 Entropy Sweep Results\n", .{}); |
| 117 | + try writer.print(" {s}\n", .{"-" * 60}); |
| 118 | + try writer.print(" {s:<8} {s:>12} {s:>6} {s:>8}\n", .{ "Gate", "Entropy", "Rank", "Passed" }); |
| 119 | + try writer.print(" {s}\n", .{"-" * 60}); |
| 120 | + for (results) |r| { |
| 121 | + const status = if (r.passed) "PASS" else "FAIL"; |
| 122 | + try writer.print(" {s:<8} {d:>12.4} {d:>6} {s:>8}\n", .{ r.gate_id, r.entropy, r.rank, status }); |
| 123 | + } |
| 124 | + try writer.print(" {s}\n", .{"-" * 60}); |
| 125 | + try writer.print(" Top-{d}:\n", .{top_results.len}); |
| 126 | + for (top_results) |r| { |
| 127 | + try writer.print(" {s}: {d:.4}\n", .{ r.gate_id, r.entropy }); |
| 128 | + } |
| 129 | + try writer.print("\n", .{}); |
| 130 | +} |
| 131 | + |
| 132 | +test "Shannon entropy uniform distribution" { |
| 133 | + const dist = [_]f32{ 0.25, 0.25, 0.25, 0.25 }; |
| 134 | + const h = computeShannonEntropy(&dist); |
| 135 | + try std.testing.expectApproxEqAbs(@as(f32, 2.0), h, 0.01); |
| 136 | +} |
| 137 | + |
| 138 | +test "Shannon entropy single value" { |
| 139 | + const dist = [_]f32{ 1.0, 0.0, 0.0 }; |
| 140 | + const h = computeShannonEntropy(&dist); |
| 141 | + try std.testing.expectApproxEqAbs(@as(f32, 0.0), h, 0.01); |
| 142 | +} |
| 143 | + |
| 144 | +test "entropy from counts" { |
| 145 | + const counts = [_]u32{ 250, 250, 250, 250 }; |
| 146 | + const h = computeEntropyFromCounts(&counts); |
| 147 | + try std.testing.expectApproxEqAbs(@as(f32, 2.0), h, 0.01); |
| 148 | +} |
| 149 | + |
| 150 | +test "sweep gate produces valid result" { |
| 151 | + const allocator = std.testing.allocator; |
| 152 | + var sweeper = EntropySweeper.init(allocator, .{ .num_samples = 1000, .vocab_size = 27 }); |
| 153 | + |
| 154 | + const result = try sweeper.sweepGate("G1", 2.0); |
| 155 | + try std.testing.expect(result.entropy > 0); |
| 156 | + try std.testing.expect(result.entropy < 10.0); |
| 157 | +} |
| 158 | + |
| 159 | +test "sweep all gates" { |
| 160 | + const allocator = std.testing.allocator; |
| 161 | + var sweeper = EntropySweeper.init(allocator, .{ .num_samples = 500, .vocab_size = 27 }); |
| 162 | + |
| 163 | + const gates = [_]EntropyGate{ |
| 164 | + .{ .id = "G1", .description = "token", .threshold = 3.0, .passed = false, .value = 0 }, |
| 165 | + .{ .id = "G2", .description = "attention", .threshold = 2.0, .passed = false, .value = 0 }, |
| 166 | + .{ .id = "G3", .description = "weight", .threshold = 1.0, .passed = false, .value = 0 }, |
| 167 | + }; |
| 168 | + |
| 169 | + const results = try sweeper.sweepAll(&gates); |
| 170 | + defer allocator.free(results); |
| 171 | + |
| 172 | + try std.testing.expectEqual(@as(usize, 3), results.len); |
| 173 | + for (results) |r| { |
| 174 | + try std.testing.expect(r.rank >= 1); |
| 175 | + } |
| 176 | +} |
0 commit comments