|
| 1 | +// HSLM — Trit-wise Attention Weights (Session 35 Quick Win #4) |
| 2 | +// Replace float32 attention weights with ternary {-1,0,+1} + per-head scales |
| 3 | +// Expected: 3× memory reduction with ~2% PPL impact |
| 4 | +// |
| 5 | +// φ² + 1/φ² = 3 = TRINITY |
| 6 | + |
| 7 | +const std = @import("std"); |
| 8 | +const math = std.math; |
| 9 | +const constants = @import("constants.zig"); |
| 10 | + |
| 11 | +const EMBED_DIM = constants.EMBED_DIM; // 243 |
| 12 | +const NUM_HEADS = constants.NUM_HEADS; // 3 |
| 13 | +const HEAD_DIM = constants.HEAD_DIM; // 81 |
| 14 | +const CONTEXT_LEN = constants.CONTEXT_LEN; // 81 |
| 15 | + |
| 16 | +const PHI_INV: f32 = 0.618033988749895; // φ⁻¹ |
| 17 | +const SACRED_GAMMA: f64 = constants.SACRED_GAMMA; // φ⁻³ ≈ 0.236 |
| 18 | + |
| 19 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 20 | +// TRIT-WISE ATTENTION WEIGHTS |
| 21 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 22 | + |
/// Ternary ("trit-wise") attention weights with one scale factor per head.
///
/// Every weight is stored as a trit in {-1, 0, +1} (one `i8` per entry) and each
/// head carries a single `f32` scale that preserves magnitude information.
///
/// Memory footprint, taking NUM_HEADS = 3 and CONTEXT_LEN = 81:
///   dense f32 weights:    3 × 81 × 4         =  972 bytes
///   per-position scales:  3 × 81 × (1 + 4)   = 1215 bytes (larger than dense; rejected)
///   per-head scales:      3 × 81 × 1 + 3 × 4 =  255 bytes (≈ 3.8× smaller than dense)
pub const TritAttentionWeights = struct {
    /// Ternary weights, packed head-major: head `h`, position `p` lives at
    /// index `h * stored_seq_len + p`.
    weights: [NUM_HEADS * CONTEXT_LEN]i8,

    /// Per-head scale factors. `quantizeFromFloat` sets
    /// `scale_h = max(0.1, mean(|w_h|))`, or 1.0 when the head's absolute sum
    /// is not above 1e-6.
    scales: [NUM_HEADS]f32,

    /// Threshold applied to `value / scale` when quantizing (default φ⁻² ≈ 0.382).
    quantization_threshold: f32 = 0.382,

    /// Stored for the caller's convenience; no method in this struct allocates.
    allocator: std.mem.Allocator,

    /// Positions per head used by the most recent `quantizeFromFloat` call.
    /// `headEntropy` and `sparsity` use it as the head stride so they read the
    /// same layout that quantization wrote.
    stored_seq_len: usize = CONTEXT_LEN,

    const Self = @This();

    /// Create an instance with all-zero trits, unit scales, the default
    /// threshold and a head stride of CONTEXT_LEN. Never actually fails; the
    /// error union is kept so existing `try` call sites keep compiling.
    pub fn init(allocator: std.mem.Allocator) !Self {
        return Self{
            .weights = [_]i8{0} ** (NUM_HEADS * CONTEXT_LEN),
            .scales = [_]f32{1.0} ** NUM_HEADS,
            .quantization_threshold = 0.382,
            .allocator = allocator,
            .stored_seq_len = CONTEXT_LEN,
        };
    }

    /// Quantize `float_weights` (head-major, `num_heads * seq_len` entries) to
    /// trits and compute one scale per head.
    ///
    /// Requires `num_heads <= NUM_HEADS` and
    /// `num_heads * seq_len <= NUM_HEADS * CONTEXT_LEN`. Trits beyond
    /// `num_heads * seq_len` are left untouched.
    pub fn quantizeFromFloat(self: *TritAttentionWeights, float_weights: []const f32, num_heads: usize, seq_len: usize) void {
        std.debug.assert(float_weights.len == num_heads * seq_len);
        // The backing arrays are fixed-size; reject shapes that would overrun them.
        std.debug.assert(num_heads <= self.scales.len);
        std.debug.assert(float_weights.len <= self.weights.len);

        // Remember the stride so the per-head statistics index the same layout.
        self.stored_seq_len = seq_len;

        const thr = self.quantization_threshold;
        for (0..num_heads) |h| {
            const start = h * seq_len;
            const end = start + seq_len;
            const src = float_weights[start..end];
            const dst = self.weights[start..end];

            // Step 1: per-head scale = mean absolute value, floored at 0.1.
            var abs_sum: f32 = 0.0;
            for (src) |value| abs_sum += @abs(value);

            const scale_h: f32 = if (abs_sum > 1e-6)
                @max(0.1, abs_sum / @as(f32, @floatFromInt(seq_len)))
            else
                1.0; // (near-)zero head: avoid dividing by ~0
            self.scales[h] = scale_h;

            // Step 2: threshold the scaled value into {-1, 0, +1}.
            for (src, dst) |value, *trit| {
                const scaled = value / scale_h;
                trit.* = if (scaled > thr)
                    1
                else if (scaled < -thr)
                    -1
                else
                    0;
            }
        }
    }

    /// Write `trit * scale_h` for every entry into `output` (head-major,
    /// `num_heads * seq_len` entries). The shape should match the one used in
    /// `quantizeFromFloat`; it is subject to the same capacity limits.
    pub fn reconstructToFloat(self: *const TritAttentionWeights, output: []f32, num_heads: usize, seq_len: usize) void {
        std.debug.assert(output.len == num_heads * seq_len);
        std.debug.assert(num_heads <= self.scales.len);
        std.debug.assert(output.len <= self.weights.len);

        for (0..num_heads) |h| {
            const start = h * seq_len;
            const end = start + seq_len;
            const scale_h = self.scales[h];

            for (output[start..end], self.weights[start..end]) |*out, trit| {
                out.* = @as(f32, @floatFromInt(trit)) * scale_h;
            }
        }
    }

    /// Shannon entropy (natural log) of the {-1, 0, +1} distribution of one
    /// head's trits. Returns 0 for an empty head. Intended for analysis and
    /// debugging.
    pub fn headEntropy(self: *const TritAttentionWeights, head: usize) f32 {
        const trits = self.headTrits(head);
        if (trits.len == 0) return 0.0;

        var counts = [3]usize{ 0, 0, 0 }; // buckets: -1, 0, +1
        for (trits) |trit| {
            const idx: usize = if (trit < 0) 0 else if (trit > 0) 2 else 1;
            counts[idx] += 1;
        }

        const total: f32 = @floatFromInt(trits.len);
        var entropy: f32 = 0.0;
        for (counts) |count| {
            if (count == 0) continue;
            const p = @as(f32, @floatFromInt(count)) / total;
            if (p > 1e-6) {
                entropy -= p * @log(p);
            }
        }
        return entropy;
    }

    /// Fraction of one head's trits that are zero. Returns 0 for an empty head.
    pub fn sparsity(self: *const TritAttentionWeights, head: usize) f32 {
        const trits = self.headTrits(head);
        if (trits.len == 0) return 0.0;

        var zero_count: usize = 0;
        for (trits) |trit| {
            if (trit == 0) zero_count += 1;
        }
        return @as(f32, @floatFromInt(zero_count)) / @as(f32, @floatFromInt(trits.len));
    }

    /// Slice of the trits belonging to `head` under the currently stored stride.
    fn headTrits(self: *const Self, head: usize) []const i8 {
        const start = head * self.stored_seq_len;
        const end = start + self.stored_seq_len;
        std.debug.assert(end <= self.weights.len);
        return self.weights[start..end];
    }
};
| 156 | + |
| 157 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 158 | +// TESTS |
| 159 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 160 | + |
test "trit attention: quantization preserves sparsity pattern" {
    var trit_attn = try TritAttentionWeights.init(std.testing.allocator);

    // One constant value per head:
    //   head 0: strong positive -> every trit +1
    //   head 1: strong negative -> every trit -1
    //   head 2: weak            -> every trit 0
    // The weak value must stay below 0.1 * threshold (0.0382): the per-head
    // scale is floored at 0.1, so e.g. 0.05 would scale to 0.5 and become +1.
    const head_values = [_]f32{ 1.0, -1.0, 0.01 };

    // Heads are CONTEXT_LEN long so the layout written by quantizeFromFloat
    // matches the head stride used by sparsity().
    var float_weights: [head_values.len * CONTEXT_LEN]f32 = undefined;
    for (head_values, 0..) |value, h| {
        @memset(float_weights[h * CONTEXT_LEN .. (h + 1) * CONTEXT_LEN], value);
    }

    trit_attn.quantizeFromFloat(&float_weights, head_values.len, CONTEXT_LEN);

    // Head 2: weak values -> all zeros -> fully sparse.
    const sparsity_h2 = trit_attn.sparsity(2);
    try std.testing.expect(sparsity_h2 > 0.5);
    try std.testing.expectApproxEqAbs(@as(f32, 1.0), sparsity_h2, 1e-6);

    // Head 0: strong values -> all +1 -> no zeros at all.
    const sparsity_h0 = trit_attn.sparsity(0);
    try std.testing.expect(sparsity_h0 < 0.5);
    try std.testing.expectApproxEqAbs(@as(f32, 0.0), sparsity_h0, 1e-6);
}
| 196 | + |
test "trit attention: reconstruction is consistent" {
    var trit_attn = try TritAttentionWeights.init(std.testing.allocator);

    const seq_len = 5;
    // One constant value per head: positive, negative, weak.
    // The weak value is 0.01 rather than 0.05: the per-head scale is floored at
    // 0.1, so 0.05 would scale to 0.5 > 0.382 and quantize to +1 instead of 0.
    const head_values = [_]f32{ 1.0, -1.0, 0.01 };
    const total = head_values.len * seq_len;

    var float_weights: [total]f32 = undefined;
    for (head_values, 0..) |value, h| {
        @memset(float_weights[h * seq_len .. (h + 1) * seq_len], value);
    }

    trit_attn.quantizeFromFloat(&float_weights, head_values.len, seq_len);

    var reconstructed: [total]f32 = undefined;
    trit_attn.reconstructToFloat(&reconstructed, head_values.len, seq_len);

    // Head 0: trit +1 times scale 1.0.
    for (reconstructed[0..seq_len]) |value| {
        try std.testing.expect(value > 0);
        try std.testing.expectApproxEqAbs(@as(f32, 1.0), value, 1e-6);
    }

    // Head 1: trit -1 times scale 1.0.
    for (reconstructed[seq_len .. 2 * seq_len]) |value| {
        try std.testing.expect(value < 0);
        try std.testing.expectApproxEqAbs(@as(f32, -1.0), value, 1e-6);
    }

    // Head 2: weak values quantize to 0, so every reconstructed entry is 0.
    var h2_zeros: usize = 0;
    for (reconstructed[2 * seq_len .. 3 * seq_len]) |value| {
        if (value == 0) h2_zeros += 1;
    }
    try std.testing.expectEqual(@as(usize, seq_len), h2_zeros);
}
| 245 | + |
test "trit attention: entropy is bounded" {
    var trit_attn = try TritAttentionWeights.init(std.testing.allocator);

    // Fill three full-length heads with uniform random values in [-1, 1).
    var prng = std.Random.DefaultPrng.init(12345);
    const rng = prng.random();
    var float_weights: [3 * 81]f32 = undefined;
    for (&float_weights) |*w| w.* = rng.float(f32) * 2.0 - 1.0;

    trit_attn.quantizeFromFloat(&float_weights, 3, 81);

    // A three-symbol distribution has entropy at most ln(3) ≈ 1.099 (reached
    // when -1, 0 and +1 are equally likely); 1.2 leaves a little slack.
    const upper_bound: f32 = 1.2;
    for (0..3) |head| {
        const entropy = trit_attn.headEntropy(head);
        try std.testing.expect(entropy >= 0.0);
        try std.testing.expect(entropy <= upper_bound);
    }
}
| 272 | + |
test "trit attention: scales are positive" {
    var trit_attn = try TritAttentionWeights.init(std.testing.allocator);

    // Three heads of ten uniform random values in [-1, 1).
    var prng = std.Random.DefaultPrng.init(54321);
    const rng = prng.random();
    var float_weights: [3 * 10]f32 = undefined;
    for (&float_weights) |*w| {
        w.* = rng.float(f32) * 2.0 - 1.0;
    }

    trit_attn.quantizeFromFloat(&float_weights, 3, 10);

    // Whatever the input, every per-head scale must be strictly positive.
    for (trit_attn.scales) |scale| {
        try std.testing.expect(scale > 0.0);
    }
}
| 292 | + |
test "trit attention: phi-threshold produces correct sparsity" {
    var trit_attn = try TritAttentionWeights.init(std.testing.allocator);
    trit_attn.quantization_threshold = 0.382; // φ⁻²

    // A single head: first five entries weak (0.1), last five strong (1.0).
    var float_weights: [1 * 10]f32 = undefined;
    for (&float_weights, 0..) |*w, pos| {
        const value: f32 = if (pos < 5) 0.1 else 1.0;
        w.* = value;
    }

    trit_attn.quantizeFromFloat(&float_weights, 1, 10);

    // Tally the trits of the ten quantized positions.
    var zero_count: usize = 0;
    var one_count: usize = 0;
    for (trit_attn.weights[0..10]) |trit| {
        switch (trit) {
            0 => zero_count += 1,
            1 => one_count += 1,
            else => {},
        }
    }

    // Weak values fall below the threshold (-> 0); strong ones exceed it (-> +1).
    try std.testing.expectEqual(@as(usize, 5), zero_count);
    try std.testing.expectEqual(@as(usize, 5), one_count);
}
| 323 | + |
test "trit attention: reconstruction with zero input" {
    var trit_attn = try TritAttentionWeights.init(std.testing.allocator);

    // An all-zero input takes the fallback branch: scale 1.0, every trit 0.
    const float_weights = [_]f32{0.0} ** (3 * 10);
    trit_attn.quantizeFromFloat(&float_weights, 3, 10);

    // Each head's scale falls back to exactly 1.0.
    for (trit_attn.scales) |scale| {
        try std.testing.expectApproxEqAbs(@as(f32, 1.0), scale, 1e-6);
    }

    // No trit may be non-zero anywhere in the backing array.
    for (trit_attn.weights) |trit| {
        try std.testing.expect(trit == 0);
    }
}
0 commit comments