Skip to content

Commit 49751b5

Browse files
Antigravity Agentclaude
andcommitted
feat(hslm): add T-JEPA modules — EMA, masking, MSE loss, JEPA encoder, trainer
5 new files for T-JEPA (Ternary Joint Embedding Predictive Architecture): - ema.zig: exponential moving average for target encoder - mask.zig: span masking with ternary-aligned spans (3, 9) - mse_loss.zig: MSE loss for JEPA representation matching - tjepa.zig: context/target encoders with EMA sync - tjepa_trainer.zig: training loop with curriculum masking Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 63f61fd commit 49751b5

5 files changed

Lines changed: 1346 additions & 0 deletions

File tree

src/hslm/ema.zig

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// T-JEPA — EMA (Exponential Moving Average) Weight Synchronization
2+
// Target encoder = EMA of online encoder shadow floats
3+
// After EMA update, target requantizes ternary weights from updated shadows
4+
//
5+
// φ² + 1/φ² = 3 = TRINITY
6+
7+
const std = @import("std");
8+
const constants = @import("constants.zig");
9+
const model_mod = @import("model.zig");
10+
11+
const EMBED_DIM = constants.EMBED_DIM;
12+
const HIDDEN_DIM = constants.HIDDEN_DIM;
13+
const VOCAB_SIZE = constants.VOCAB_SIZE;
14+
const NUM_BLOCKS = constants.NUM_BLOCKS;
15+
16+
// ═══════════════════════════════════════════════════════════════════════════════
17+
// EMA SYNC
18+
// ═══════════════════════════════════════════════════════════════════════════════
19+
20+
/// Exponential-moving-average synchronizer: blends the float parameters of an
/// online model into a target model using a scheduled decay.
pub const EmaSync = struct {
    /// Decay applied at step 0 (e.g. 0.996). A lower decay lets more of the
    /// online weights through on each update.
    decay_start: f32,
    /// Decay reached once `step >= total_steps` (e.g. 1.0, where the target
    /// no longer moves).
    decay_end: f32,

    /// Blend `online_shadow` into `target_shadow` element-wise, in place:
    /// `target[i] = decay * target[i] + (1 - decay) * online[i]`.
    /// Both slices must have the same length (checked via `std.debug.assert`).
    pub fn updateShadows(target_shadow: []f32, online_shadow: []const f32, decay: f32) void {
        std.debug.assert(target_shadow.len == online_shadow.len);
        const online_weight = 1.0 - decay;
        var i: usize = 0;
        while (i < target_shadow.len) : (i += 1) {
            target_shadow[i] = decay * target_shadow[i] + online_weight * online_shadow[i];
        }
    }

    /// Decay value the schedule yields at `step` out of `total_steps`.
    pub fn currentDecay(self: *const EmaSync, step: u32, total_steps: u32) f32 {
        return scheduledDecay(step, total_steps, self.decay_start, self.decay_end);
    }

    /// Apply one EMA update from `online` to `target` using the decay scheduled
    /// for `step`. The following buffers are updated, in this order:
    ///   1. the output projection shadow,
    ///   2. for every block: TNN up/down shadows and biases, the sacred
    ///      attention q/k/v/o shadows, and the attention `rms_gamma`,
    ///   3. the embedding float table.
    /// Only these float buffers are written; nothing else on `target` is touched here.
    pub fn syncModels(self: *const EmaSync, target: *model_mod.HSLM, online: *const model_mod.HSLM, step: u32, total_steps: u32) void {
        const decay = self.currentDecay(step, total_steps);

        // Output projection.
        updateShadows(target.output_shadow, online.output_shadow, decay);

        // Walk the target and online blocks in lock-step.
        for (&target.blocks, &online.blocks) |*dst, *src| {
            // Dense TNN layer: weights first, then biases.
            updateShadows(dst.tnn.shadow_up, src.tnn.shadow_up, decay);
            updateShadows(dst.tnn.shadow_down, src.tnn.shadow_down, decay);
            updateShadows(dst.tnn.bias_up, src.tnn.bias_up, decay);
            updateShadows(dst.tnn.bias_down, src.tnn.bias_down, decay);

            // Attention projections.
            updateShadows(dst.sacred_attn.shadow_q, src.sacred_attn.shadow_q, decay);
            updateShadows(dst.sacred_attn.shadow_k, src.sacred_attn.shadow_k, decay);
            updateShadows(dst.sacred_attn.shadow_v, src.sacred_attn.shadow_v, decay);
            updateShadows(dst.sacred_attn.shadow_o, src.sacred_attn.shadow_o, decay);

            // Normalization gain.
            updateShadows(dst.sacred_attn.rms_gamma, src.sacred_attn.rms_gamma, decay);
        }

        // Embedding table.
        updateShadows(target.emb.float_table, online.emb.float_table, decay);
    }
};
68+
69+
/// Linearly interpolate a decay value between `start` and `end`.
///
/// Progress is `step / total_steps`, capped at 1.0, so any `step` at or past
/// `total_steps` yields `end`. When `total_steps` is 0 there is nothing to
/// ramp over and `end` is returned directly.
pub fn scheduledDecay(step: u32, total_steps: u32, start: f32, end: f32) f32 {
    if (total_steps == 0) return end;

    const step_f: f32 = @floatFromInt(step);
    const total_f: f32 = @floatFromInt(total_steps);
    const progress = @min(step_f / total_f, 1.0);

    return start + (end - start) * progress;
}
75+
76+
// ═══════════════════════════════════════════════════════════════════════════════
77+
// TESTS
78+
// ═══════════════════════════════════════════════════════════════════════════════
79+
80+
test "ema decay formula" {
    // With an all-zero online vector, each target element is just scaled by
    // the decay: 0.996 * t + 0.004 * 0.0.
    const decay: f32 = 0.996;
    var target = [_]f32{ 1.0, 2.0, 3.0 };
    const online = [_]f32{ 0.0, 0.0, 0.0 };

    EmaSync.updateShadows(&target, &online, decay);

    const expected = [_]f32{ 0.996, 1.992, 2.988 };
    for (expected, target) |want, got| {
        try std.testing.expectApproxEqAbs(want, got, 1e-5);
    }
}
89+
90+
test "ema schedule ramp" {
    // Ramp from 0.996 to 1.0 over 100 steps, sampled at key points.
    const Case = struct { step: u32, want: f32 };
    const cases = [_]Case{
        .{ .step = 0, .want = 0.996 }, // start of the ramp
        .{ .step = 50, .want = 0.998 }, // midpoint
        .{ .step = 100, .want = 1.0 }, // end of the ramp
        .{ .step = 200, .want = 1.0 }, // past the end: clamped
    };

    for (cases) |case| {
        const got = scheduledDecay(case.step, 100, 0.996, 1.0);
        try std.testing.expectApproxEqAbs(case.want, got, 1e-6);
    }
}
100+
101+
test "ema sync models" {
    const allocator = std.testing.allocator;

    var online = try model_mod.HSLM.init(allocator);
    defer online.deinit();
    var target = try model_mod.HSLM.init(allocator);
    defer target.deinit();

    // Both models come from the same `init` call, so they may start out with
    // identical weights (NOTE(review): confirm against HSLM.init). Shift the
    // online weights first; otherwise the equality checks below could pass
    // even if `syncModels` did nothing.
    for (online.output_shadow) |*w| w.* += 1.0;
    for (online.blocks[0].tnn.shadow_up) |*w| w.* += 1.0;

    // decay = 0 makes the update `target = 0 * target + 1 * online`, i.e. a full copy.
    const ema = EmaSync{ .decay_start = 0.0, .decay_end = 0.0 };
    ema.syncModels(&target, &online, 0, 100);

    // Output projection: target must now match the (shifted) online weights.
    for (online.output_shadow, target.output_shadow) |want, got| {
        try std.testing.expectApproxEqAbs(want, got, 1e-6);
    }
    // Spot-check one per-block buffer as well.
    for (online.blocks[0].tnn.shadow_up, target.blocks[0].tnn.shadow_up) |want, got| {
        try std.testing.expectApproxEqAbs(want, got, 1e-6);
    }
}

src/hslm/mask.zig

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// T-JEPA — Block Masking for Sequences
2+
// Contiguous span masking: harder prediction → better representations
3+
// Spans aligned to ternary powers (3^1=3, 3^2=9)
4+
//
5+
// φ² + 1/φ² = 3 = TRINITY
6+
7+
const std = @import("std");
8+
const constants = @import("constants.zig");
9+
10+
const CONTEXT_LEN = constants.CONTEXT_LEN;
11+
12+
// ═══════════════════════════════════════════════════════════════════════════════
13+
// MASK CONFIG
14+
// ═══════════════════════════════════════════════════════════════════════════════
15+
16+
/// Tunable parameters for `generateMask`.
pub const MaskConfig = struct {
    /// Upper bound on the fraction of the sequence that may be masked (default 30%).
    mask_ratio: f32 = 0.3,
    /// Shortest span length that can be sampled (default 3 = 3^1).
    min_span: usize = 3,
    /// Longest span length that can be sampled (default 9 = 3^2).
    max_span: usize = 9,
    /// Maximum number of spans drawn per mask (default 2).
    num_spans: usize = 2,
};
22+
23+
/// Outcome of a masking pass over a sequence of at most `CONTEXT_LEN` positions.
pub const MaskResult = struct {
    /// Per-position flag: `true` = visible, `false` = masked.
    visible: [CONTEXT_LEN]bool,
    /// Number of valid entries at the front of `visible_positions`.
    num_visible: usize,
    /// Number of valid entries at the front of `masked_positions`.
    num_masked: usize,
    /// Packed list of masked indices; only the first `num_masked` entries are meaningful.
    masked_positions: [CONTEXT_LEN]usize,
    /// Packed list of visible indices; only the first `num_visible` entries are meaningful.
    visible_positions: [CONTEXT_LEN]usize,

    /// Blank result: every position flagged visible, both counters at zero,
    /// and both packed index lists zero-filled.
    pub fn init() MaskResult {
        const no_positions = [_]usize{0} ** CONTEXT_LEN;
        return MaskResult{
            .visible = [_]bool{true} ** CONTEXT_LEN,
            .num_visible = 0,
            .num_masked = 0,
            .masked_positions = no_positions,
            .visible_positions = no_positions,
        };
    }
};
40+
41+
// ═══════════════════════════════════════════════════════════════════════════════
42+
// MASK GENERATION
43+
// ═══════════════════════════════════════════════════════════════════════════════
44+
45+
/// Generate a contiguous-span mask for a sequence.
///
/// Steps:
///   1. Only the first `min(seq_len, CONTEXT_LEN)` positions are considered.
///   2. Up to `config.num_spans` spans are drawn, each with a length uniform in
///      `[min_span, max_span]` and a uniformly random start that keeps the span
///      inside the sequence.
///   3. Overlapping spans are merged: a position already masked is not counted twice.
///   4. The total number of masked positions never exceeds
///      `floor(effective_len * mask_ratio)`.
///
/// Input handling:
///   - `seq_len == 0` returns the blank `MaskResult.init()` value.
///   - `mask_ratio` is clamped to `[0, 1]`; a negative or NaN ratio masks nothing.
///   - If `min_span > max_span` the two bounds are swapped rather than underflowing.
///
/// Returns a `MaskResult` whose packed position lists and counters cover the
/// considered positions; positions beyond them keep the `visible = true` flag.
pub fn generateMask(seq_len: usize, config: MaskConfig, rng: std.Random) MaskResult {
    // `MaskResult.init()` already marks every position visible, so no extra
    // initialization pass is needed.
    var result = MaskResult.init();
    if (seq_len == 0) return result;

    const effective_len: usize = @min(seq_len, CONTEXT_LEN);

    // `@intFromFloat` is illegal for negative, NaN or out-of-range values, so
    // sanitize the ratio first. `NaN > 0.0` is false, which maps NaN to 0.
    const ratio: f32 = if (config.mask_ratio > 0.0) @min(config.mask_ratio, 1.0) else 0.0;
    const max_masked: usize = @intFromFloat(@as(f32, @floatFromInt(effective_len)) * ratio);

    // Order the span bounds so that `hi - lo` cannot underflow when the caller
    // passes them reversed.
    const span_lo = @min(config.min_span, config.max_span);
    const span_hi = @max(config.min_span, config.max_span);
    const span_range = span_hi - span_lo + 1;

    // Draw spans until the budget is used up or `num_spans` have been placed.
    var total_masked: usize = 0;
    for (0..config.num_spans) |_| {
        if (total_masked >= max_masked) break;

        // Sample a span length and trim it to the remaining budget.
        const span_len = span_lo + rng.uintLessThan(usize, span_range);
        const actual_span = @min(span_len, max_masked - total_masked);

        if (actual_span == 0) break;
        // A span must be strictly shorter than the sequence.
        if (effective_len <= actual_span) break;

        // Random start such that the whole span stays inside the sequence.
        const max_start = effective_len - actual_span;
        const start = rng.uintLessThan(usize, max_start + 1);

        // Mask the span; positions masked by an earlier span are skipped so the
        // budget only counts newly masked positions.
        for (start..start + actual_span) |pos| {
            if (result.visible[pos]) {
                result.visible[pos] = false;
                total_masked += 1;
                if (total_masked >= max_masked) break;
            }
        }
    }

    // Build the packed visible / masked index lists in ascending order.
    var vi: usize = 0;
    var mi: usize = 0;
    for (0..effective_len) |i| {
        if (result.visible[i]) {
            result.visible_positions[vi] = i;
            vi += 1;
        } else {
            result.masked_positions[mi] = i;
            mi += 1;
        }
    }
    result.num_visible = vi;
    result.num_masked = mi;

    return result;
}
105+
106+
// ═══════════════════════════════════════════════════════════════════════════════
107+
// TESTS
108+
// ═══════════════════════════════════════════════════════════════════════════════
109+
110+
test "mask valid split" {
    const seq_len: usize = 27;
    var prng = std.Random.DefaultPrng.init(42);

    const mask = generateMask(seq_len, .{}, prng.random());

    // Every position is either visible or masked, so the counts add up to the sequence length.
    try std.testing.expectEqual(seq_len, mask.num_visible + mask.num_masked);
}
116+
117+
test "mask ratio approximate" {
    const trials: usize = 100;
    const seq_len: usize = 81;

    var prng = std.Random.DefaultPrng.init(123);
    const rng = prng.random();

    // Accumulate the masked count over many independent masks.
    var masked_sum: usize = 0;
    var trial: usize = 0;
    while (trial < trials) : (trial += 1) {
        masked_sum += generateMask(seq_len, .{}, rng).num_masked;
    }

    const masked_f: f32 = @floatFromInt(masked_sum);
    const positions_f: f32 = @floatFromInt(trials * seq_len);
    const avg_ratio = masked_f / positions_f;

    // Loose sanity band of ±0.2 around the configured 0.3 ratio.
    try std.testing.expect(avg_ratio > 0.1);
    try std.testing.expect(avg_ratio < 0.5);
}
132+
133+
test "mask spans contiguous" {
    var prng = std.Random.DefaultPrng.init(777);
    const mask = generateMask(81, .{ .num_spans = 1, .min_span = 5, .max_span = 9 }, prng.random());

    // A single span must yield consecutive masked indices.
    const masked = mask.masked_positions[0..mask.num_masked];
    if (masked.len < 2) return;

    for (masked[0 .. masked.len - 1], masked[1..]) |prev, next| {
        try std.testing.expectEqual(@as(usize, 1), next - prev);
    }
}
144+
145+
test "mask deterministic seed" {
    // Two generators seeded identically must produce the same mask.
    var prng_a = std.Random.DefaultPrng.init(42);
    var prng_b = std.Random.DefaultPrng.init(42);

    const first = generateMask(27, .{}, prng_a.random());
    const second = generateMask(27, .{}, prng_b.random());

    try std.testing.expectEqual(first.num_masked, second.num_masked);
    try std.testing.expectEqualSlices(
        usize,
        first.masked_positions[0..first.num_masked],
        second.masked_positions[0..first.num_masked],
    );
}

0 commit comments

Comments
 (0)