@@ -316,6 +316,58 @@ pub const HSLM = struct {
316316 }
317317 }
318318
/// Runs the embedding and every block over `tokens` and writes the final
/// per-position hidden states to `hidden_out`, without applying the output
/// projection. Used by T-JEPA to extract representations.
///
/// - `tokens`: input token ids; only the first `CONTEXT_LEN` are processed.
/// - `hidden_out`: receives `seq_len * EMBED_DIM` floats laid out position
///   by position, where `seq_len = @min(tokens.len, CONTEXT_LEN)`. It must be
///   at least that long, otherwise the final slice is out of bounds.
///
/// Each block's `sacred_attn.resetCache()` is called before that block
/// processes the sequence.
pub fn forwardHidden(self: *Self, tokens: []const u16, hidden_out: []f32) void {
    // Pinned to usize so the length products below are computed in usize
    // regardless of the integer type @min infers from a comptime-known bound.
    const seq_len: usize = @min(tokens.len, CONTEXT_LEN);
    const float_len = seq_len * EMBED_DIM;
    const trit_len = seq_len * VSA_DIM;

    // Embed straight into the working buffers. Arrays are value types, so
    // staging through separate arrays would cost two more CONTEXT_LEN-sized
    // stack buffers plus a full copy of each.
    var cur_float: [CONTEXT_LEN * EMBED_DIM]f32 = undefined;
    var cur_trit: [CONTEXT_LEN * VSA_DIM]i8 = undefined;
    self.emb.embedSequence(tokens[0..seq_len], &cur_float, &cur_trit);

    var next_float: [CONTEXT_LEN * EMBED_DIM]f32 = undefined;
    var next_trit: [CONTEXT_LEN * VSA_DIM]i8 = undefined;

    for (&self.blocks) |*block| {
        block.sacred_attn.resetCache();
        for (0..seq_len) |pos| {
            const f_off = pos * EMBED_DIM;
            const t_off = pos * VSA_DIM;
            // Inputs: this position's float vector and the trit prefix up to
            // and including `pos`. Outputs: this position's slot in next_*.
            block.forward(
                pos,
                cur_float[f_off .. f_off + EMBED_DIM],
                cur_trit[0 .. (pos + 1) * VSA_DIM],
                next_float[f_off .. f_off + EMBED_DIM],
                next_trit[t_off .. t_off + VSA_DIM],
            );
        }
        // This block's output becomes the next block's input. Only the first
        // seq_len positions are ever read, so only that prefix is copied;
        // next_* is left untouched, as a whole-array assignment would leave it.
        @memcpy(cur_float[0..float_len], next_float[0..float_len]);
        @memcpy(cur_trit[0..trit_len], next_trit[0..trit_len]);
    }

    // Hand back the last block's float states (no output projection).
    @memcpy(hidden_out[0..float_len], cur_float[0..float_len]);
}
353+
/// Pushes a representation-loss gradient back through the blocks, last block
/// first, skipping the output projection. Used by T-JEPA, where the gradient
/// comes from a representation loss rather than from logits.
///
/// Only the first `EMBED_DIM` values of `grad_hidden` are read, so the slice
/// must be at least that long. The gradient computed for the first block's
/// input stays local and is not returned; any other effect comes from the
/// `tnn.backward` and `sacred_attn.backward` calls themselves.
pub fn backwardHidden(self: *Self, grad_hidden: []const f32) void {
    // Seed with a copy of the leading EMBED_DIM gradient values.
    var upstream: [EMBED_DIM]f32 = grad_hidden[0..EMBED_DIM].*;
    var after_tnn: [EMBED_DIM]f32 = undefined;

    // Walk the blocks in reverse order: NUM_BLOCKS - 1 down to 0.
    var remaining: usize = NUM_BLOCKS;
    while (remaining > 0) : (remaining -= 1) {
        const block = &self.blocks[remaining - 1];

        // Within a block the tnn stage is reversed first, then attention.
        block.tnn.backward(&upstream, &after_tnn);

        // Fresh scratch buffer per block for the attention input gradient.
        var after_attn: [EMBED_DIM]f32 = undefined;
        block.sacred_attn.backward(&after_tnn, &after_attn);

        upstream = after_attn;
    }
}
370+
319371 /// Generate next token (greedy)
320372 pub fn generate (self : * Self , tokens : []const u16 ) u16 {
321373 var logits : [VOCAB_SIZE ]f32 = undefined ;
0 commit comments