|
| 1 | +// T-JEPA — EMA (Exponential Moving Average) Weight Synchronization |
| 2 | +// Target encoder = EMA of online encoder shadow floats |
| 3 | +// After EMA update, target requantizes ternary weights from updated shadows |
| 4 | +// |
| 5 | +// φ² + 1/φ² = 3 = TRINITY |
| 6 | + |
| 7 | +const std = @import("std"); |
| 8 | +const constants = @import("constants.zig"); |
| 9 | +const model_mod = @import("model.zig"); |
| 10 | + |
| 11 | +const EMBED_DIM = constants.EMBED_DIM; |
| 12 | +const HIDDEN_DIM = constants.HIDDEN_DIM; |
| 13 | +const VOCAB_SIZE = constants.VOCAB_SIZE; |
| 14 | +const NUM_BLOCKS = constants.NUM_BLOCKS; |
| 15 | + |
| 16 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 17 | +// EMA SYNC |
| 18 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 19 | + |
/// EMA (exponential moving average) synchronisation settings for a
/// target/online model pair. The decay is ramped from `decay_start` to
/// `decay_end` over training via `scheduledDecay`.
pub const EmaSync = struct {
    /// Decay used at step 0 (e.g. 0.996); a lower decay lets more of the
    /// online weights through on each update.
    decay_start: f32,
    /// Decay reached at the end of the schedule (e.g. 1.0, at which point
    /// the target no longer moves).
    decay_end: f32,

    /// Blend `online_shadow` into `target_shadow` in place, element by element:
    /// `target[i] = decay * target[i] + (1 - decay) * online[i]`.
    /// Both slices must have the same length (checked with `std.debug.assert`).
    pub fn updateShadows(target_shadow: []f32, online_shadow: []const f32, decay: f32) void {
        std.debug.assert(target_shadow.len == online_shadow.len);
        const online_weight = 1.0 - decay;
        for (0..target_shadow.len) |i| {
            target_shadow[i] = decay * target_shadow[i] + online_weight * online_shadow[i];
        }
    }

    /// Apply one EMA update from `online` to `target` using the decay
    /// scheduled for `step` out of `total_steps`.
    ///
    /// Buffers updated: the output projection shadow, every block's TNN
    /// shadows and biases, every block's sacred-attention q/k/v/o shadows and
    /// `rms_gamma`, and the embedding float table.
    /// NOTE(review): this only updates float buffers; any requantization of
    /// the target's ternary weights is not performed here — confirm the caller
    /// handles it.
    pub fn syncModels(self: *const EmaSync, target: *model_mod.HSLM, online: *const model_mod.HSLM, step: u32, total_steps: u32) void {
        const decay = self.currentDecay(step, total_steps);

        // Output projection.
        updateShadows(target.output_shadow, online.output_shadow, decay);

        // Per-block parameters, walked in lock-step over both models.
        for (&target.blocks, &online.blocks) |*dst, *src| {
            // TNN dense shadows and biases.
            inline for (.{ "shadow_up", "shadow_down", "bias_up", "bias_down" }) |field_name| {
                updateShadows(@field(dst.tnn, field_name), @field(src.tnn, field_name), decay);
            }
            // Sacred attention projections followed by the RMS gamma.
            inline for (.{ "shadow_q", "shadow_k", "shadow_v", "shadow_o", "rms_gamma" }) |field_name| {
                updateShadows(@field(dst.sacred_attn, field_name), @field(src.sacred_attn, field_name), decay);
            }
        }

        // Embedding float table.
        updateShadows(target.emb.float_table, online.emb.float_table, decay);
    }

    /// Decay that `syncModels` would use at `step` out of `total_steps`.
    pub fn currentDecay(self: *const EmaSync, step: u32, total_steps: u32) f32 {
        return scheduledDecay(step, total_steps, self.decay_start, self.decay_end);
    }
};
| 68 | + |
/// Linearly interpolate from `start` (at step 0) to `end` (at `total_steps`).
///
/// - `step`: current training step.
/// - `total_steps`: length of the ramp; 0 means "no ramp" and yields `end`.
/// - `start`, `end`: values at the beginning and end of the ramp.
///
/// Returns exactly `end` once `step >= total_steps` (or when `total_steps`
/// is 0), and exactly `start` at step 0.
pub fn scheduledDecay(step: u32, total_steps: u32, start: f32, end: f32) f32 {
    // Finished (or degenerate) schedule: hand back `end` itself.
    // `start + (end - start) * 1.0` is not guaranteed to round to `end`.
    if (total_steps == 0 or step >= total_steps) return end;

    const progress = @as(f32, @floatFromInt(step)) / @as(f32, @floatFromInt(total_steps));
    // Very large step counts can round to the same f32 as `total_steps`,
    // so keep the clamp even though `step < total_steps` here.
    return start + (end - start) * @min(progress, 1.0);
}
| 75 | + |
| 76 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 77 | +// TESTS |
| 78 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 79 | + |
test "ema decay formula" {
    // Case 1: online is all zeros, so only the `decay * target` term contributes.
    var target = [_]f32{ 1.0, 2.0, 3.0 };
    const online = [_]f32{ 0.0, 0.0, 0.0 };
    EmaSync.updateShadows(&target, &online, 0.996);
    // target[0] = 0.996 * 1.0 + 0.004 * 0.0 = 0.996
    try std.testing.expectApproxEqAbs(@as(f32, 0.996), target[0], 1e-5);
    try std.testing.expectApproxEqAbs(@as(f32, 1.992), target[1], 1e-5);
    try std.testing.expectApproxEqAbs(@as(f32, 2.988), target[2], 1e-5);

    // Case 2: non-zero online values with asymmetric weights, so the
    // `(1 - decay) * online` term is exercised and swapped weights would fail.
    var blended = [_]f32{ 1.0, 2.0, 3.0 };
    const source = [_]f32{ 10.0, 20.0, 30.0 };
    EmaSync.updateShadows(&blended, &source, 0.25);
    // blended[0] = 0.25 * 1.0 + 0.75 * 10.0 = 7.75 (exactly representable)
    try std.testing.expectEqualSlices(f32, &[_]f32{ 7.75, 15.5, 23.25 }, &blended);
}
| 89 | + |
test "ema schedule ramp" {
    // Ramp 0.996 -> 1.0 over 100 steps, sampled at the start, the midpoint,
    // the end, and past the end (clamped).
    const Case = struct { step: u32, expected: f32 };
    const cases = [_]Case{
        .{ .step = 0, .expected = 0.996 },
        .{ .step = 50, .expected = 0.998 },
        .{ .step = 100, .expected = 1.0 },
        .{ .step = 200, .expected = 1.0 },
    };
    for (cases) |case| {
        const actual = scheduledDecay(case.step, 100, 0.996, 1.0);
        try std.testing.expectApproxEqAbs(case.expected, actual, 1e-6);
    }
}
| 100 | + |
test "ema sync models" {
    const allocator = std.testing.allocator;

    var online = try model_mod.HSLM.init(allocator);
    defer online.deinit();
    var target = try model_mod.HSLM.init(allocator);
    defer target.deinit();

    // Push the target away from the online weights first. Without this the
    // assertions below could pass even if the sync did nothing, should both
    // models happen to be initialised identically.
    for (target.output_shadow) |*w| w.* += 1.0;
    for (target.blocks[0].tnn.shadow_up) |*w| w.* += 1.0;
    for (target.emb.float_table) |*w| w.* += 1.0;

    const ema = EmaSync{ .decay_start = 0.0, .decay_end = 0.0 };
    // decay=0 means target = online (full copy)
    ema.syncModels(&target, &online, 0, 100);

    // After decay=0 sync, target shadows should equal online shadows.
    // Argument order is (expected, actual, tolerance).
    for (online.output_shadow, target.output_shadow) |expected, actual| {
        try std.testing.expectApproxEqAbs(expected, actual, 1e-6);
    }
    // Check one block
    for (online.blocks[0].tnn.shadow_up, target.blocks[0].tnn.shadow_up) |expected, actual| {
        try std.testing.expectApproxEqAbs(expected, actual, 1e-6);
    }
    // Embedding float table is synced as well
    for (online.emb.float_table, target.emb.float_table) |expected, actual| {
        try std.testing.expectApproxEqAbs(expected, actual, 1e-6);
    }
}
0 commit comments