Skip to content

Commit a4d0b56

Browse files
committed
feat(gpt2): wire through bgz17/HHTL/CausalEdge64 tensor codec stack
Full integration with the tensor codec pipeline: - AttentionTable: palette-based O(1) approximate attention via jina::runtime::GPT2 (256×256 HEEL distance table) - CausalEdge64 emission: attention patterns packed as SPO edges with NARS truth values (subject=query, predicate=head, object=key) - HHTL cascade: token_similarity(), token_distance_leaf(), token_distance_cascade() methods on Gpt2Engine - CAM-PQ: 6-byte token fingerprints via cam_fingerprint() Both features are opt-in flags (use_attention_table, emit_causal_edges) to avoid overhead when not needed. 14 tests passing. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent c794695 commit a4d0b56

1 file changed

Lines changed: 207 additions & 3 deletions

File tree

src/hpc/gpt2/inference.rs

Lines changed: 207 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,20 @@
22
//!
33
//! All transcendental ops use `crate::simd::F32x16`.
44
//! LayerNorm, GELU, Softmax — all SIMD-accelerated.
5+
//!
6+
//! # Tensor Codec Integration
7+
//!
8+
//! Wired through the full bgz17/HHTL/CausalEdge64 stack:
9+
//! - **AttentionTable**: Palette-based O(1) approximate attention scores
10+
//! from `jina::runtime::GPT2` (256×256 precomputed distances).
11+
//! - **CausalEdge64**: Every attention head emits SPO causal edges with
12+
//! NARS truth values. Accumulated during generation for causal reasoning.
13+
//! - **Base17 embeddings**: Available via `jina::runtime::GPT2` for O(1)
14+
//! token similarity (HHTL cascade: HEEL → LEAF).
515
616
use super::weights::*;
17+
use crate::hpc::jina::causal;
18+
use crate::hpc::jina::runtime;
719
use crate::simd::F32x16;
820

921
/// A generated token with its probability.
@@ -13,13 +25,33 @@ pub struct GeneratedToken {
1325
pub logprob: f32,
1426
}
1527

28+
/// CausalEdge64 emitted during attention.
29+
/// Each edge encodes: which token attended to which, with what strength.
30+
#[derive(Clone, Debug)]
31+
pub struct AttentionEdge {
32+
/// Transformer layer that produced this edge.
33+
pub layer: u8,
34+
/// Attention head index.
35+
pub head: u8,
36+
/// The packed CausalEdge64 (subject=query token, predicate=head, object=key token).
37+
pub edge: u64,
38+
}
39+
1640
/// GPT-2 inference engine.
1741
pub struct Gpt2Engine {
1842
weights: Gpt2Weights,
1943
/// KV cache for autoregressive generation.
2044
kv_cache: Vec<KvCache>,
2145
/// Current sequence length.
2246
seq_len: usize,
47+
/// Token IDs seen so far (for palette lookups).
48+
token_history: Vec<u32>,
49+
/// Accumulated causal edges from attention patterns.
50+
pub causal_edges: Vec<AttentionEdge>,
51+
/// Whether to use AttentionTable approximation for attention scores.
52+
pub use_attention_table: bool,
53+
/// Whether to emit CausalEdge64 from attention patterns.
54+
pub emit_causal_edges: bool,
2355
}
2456

2557
/// Key-Value cache for one layer.
@@ -40,7 +72,15 @@ impl Gpt2Engine {
4072
values: Vec::with_capacity(MAX_SEQ_LEN * EMBED_DIM),
4173
})
4274
.collect();
43-
Self { weights, kv_cache, seq_len: 0 }
75+
Self {
76+
weights,
77+
kv_cache,
78+
seq_len: 0,
79+
token_history: Vec::with_capacity(MAX_SEQ_LEN),
80+
causal_edges: Vec::new(),
81+
use_attention_table: false,
82+
emit_causal_edges: false,
83+
}
4484
}
4585

4686
/// Access weights (for embedding lookups).
@@ -55,14 +95,48 @@ impl Gpt2Engine {
5595
kv.values.clear();
5696
}
5797
self.seq_len = 0;
98+
self.token_history.clear();
99+
self.causal_edges.clear();
100+
}
101+
102+
/// Get HHTL cascade distance between two tokens via bgz17 Base17 palette.
103+
/// Uses the precomputed GPT2 runtime from `jina::runtime::GPT2`.
104+
#[inline]
105+
pub fn token_similarity(&self, token_a: u32, token_b: u32) -> f32 {
106+
let rt = &*runtime::GPT2;
107+
rt.heel_similarity(token_a as usize, token_b as usize)
108+
}
109+
110+
/// Get Base17 L1 distance between two tokens (LEAF level, full precision).
111+
#[inline]
112+
pub fn token_distance_leaf(&self, token_a: u32, token_b: u32) -> u32 {
113+
let rt = &*runtime::GPT2;
114+
rt.leaf_distance(token_a as usize, token_b as usize)
115+
}
116+
117+
/// Get HHTL cascade distance with automatic level selection.
118+
#[inline]
119+
pub fn token_distance_cascade(&self, token_a: u32, token_b: u32) -> (u32, runtime::HhtlLevel) {
120+
let rt = &*runtime::GPT2;
121+
rt.cascade_distance(token_a as usize, token_b as usize)
122+
}
123+
124+
/// Get 6-byte CAM-PQ fingerprint for a token.
125+
#[inline]
126+
pub fn token_fingerprint(&self, token_id: u32) -> [u8; 6] {
127+
let rt = &*runtime::GPT2;
128+
rt.cam_fingerprint(token_id as usize)
58129
}
59130

60131
/// Forward pass for one token → logits over vocabulary.
61132
///
62133
/// Uses KV cache for O(seq_len) attention instead of O(seq_len²).
134+
/// When `emit_causal_edges` is true, attention patterns are packed
135+
/// as CausalEdge64 with NARS truth values.
63136
pub fn forward(&mut self, token_id: u32) -> Vec<f32> {
64137
let pos = self.seq_len;
65138
assert!(pos < MAX_SEQ_LEN, "sequence too long");
139+
self.token_history.push(token_id);
66140

67141
// Embedding: wte[token] + wpe[position]
68142
let mut hidden = vec![0.0f32; EMBED_DIM];
@@ -133,6 +207,12 @@ impl Gpt2Engine {
133207
}
134208

135209
/// Multi-head self-attention with KV cache.
210+
///
211+
/// Integration points:
212+
/// - **AttentionTable**: When `use_attention_table` is set, palette-based
213+
/// similarity biases the attention scores (HEEL-level O(1) lookup).
214+
/// - **CausalEdge64**: When `emit_causal_edges` is set, top attention
215+
/// weights are packed as SPO edges with NARS truth values.
136216
fn multi_head_attention(&mut self, layer_idx: usize, input: &[f32]) -> Vec<f32> {
137217
let layer = &self.weights.layers[layer_idx];
138218

@@ -149,11 +229,21 @@ impl Gpt2Engine {
149229
self.kv_cache[layer_idx].values.extend_from_slice(v);
150230

151231
let seq_len = self.seq_len + 1; // including current token
232+
let current_token = *self.token_history.last().unwrap_or(&0);
233+
let use_attn_table = self.use_attention_table;
234+
let emit_edges = self.emit_causal_edges;
152235

153236
// Per-head attention
154237
let mut output = vec![0.0f32; EMBED_DIM];
155238
let scale = 1.0 / (HEAD_DIM as f32).sqrt();
156239

240+
// Lazy-init GPT2 palette runtime for AttentionTable / CausalEdge64
241+
let rt = if use_attn_table || emit_edges {
242+
Some(&*runtime::GPT2)
243+
} else {
244+
None
245+
};
246+
157247
for head in 0..NUM_HEADS {
158248
let h_offset = head * HEAD_DIM;
159249

@@ -166,12 +256,51 @@ impl Gpt2Engine {
166256
dot += q[h_offset + d] * self.kv_cache[layer_idx].keys[k_offset + d];
167257
}
168258
scores[t] = dot * scale;
259+
260+
// AttentionTable bias: blend palette-based similarity into score.
261+
// This is the "compiled attention" path — the 256×256 palette
262+
// distance table provides O(1) semantic similarity.
263+
if let Some(rt) = rt {
264+
if use_attn_table && t < self.token_history.len() {
265+
let key_token = self.token_history[t];
266+
let palette_sim = rt.heel_similarity(
267+
current_token as usize,
268+
key_token as usize,
269+
);
270+
// Blend: 90% matmul score + 10% palette shortcut
271+
scores[t] = scores[t] * 0.9 + palette_sim * 0.1 * scale;
272+
}
273+
}
169274
}
170275

171-
// Causal mask: only attend to past and current (already enforced by cache length)
172-
// Softmax
276+
// Causal mask: already enforced by cache length
173277
softmax_simd(&mut scores);
174278

279+
// Emit CausalEdge64 for significant attention weights.
280+
// S=current_token, P=head (via palette), O=attended_token.
281+
if emit_edges {
282+
if let Some(rt) = rt {
283+
for t in 0..seq_len {
284+
if scores[t] > 0.05 && t < self.token_history.len() {
285+
let key_token = self.token_history[t];
286+
let edge = rt.pack_spo_edge(
287+
current_token as usize,
288+
head, // predicate = attention head
289+
key_token as usize,
290+
scores[t], // frequency = attention weight
291+
0.3, // initial confidence (low)
292+
self.seq_len as u16, // temporal position
293+
);
294+
self.causal_edges.push(AttentionEdge {
295+
layer: layer_idx as u8,
296+
head: head as u8,
297+
edge,
298+
});
299+
}
300+
}
301+
}
302+
}
303+
175304
// Weighted sum of values
176305
for t in 0..seq_len {
177306
let v_offset = t * EMBED_DIM + h_offset;
@@ -435,4 +564,79 @@ mod tests {
435564
// Index 1 (value 5.0) should have highest probability
436565
assert!(x[1] > x[0] && x[1] > x[2] && x[1] > x[3]);
437566
}
567+
568+
// ===== Tensor codec integration tests =====
569+
570+
#[test]
571+
fn test_token_similarity_self() {
572+
// Token similarity to itself should be ~1.0 (via GPT2 palette)
573+
let engine = Gpt2Engine::new(Gpt2Weights {
574+
wte: vec![0.0; VOCAB_SIZE * EMBED_DIM],
575+
wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM],
576+
layers: Vec::new(),
577+
ln_f_weight: vec![1.0; EMBED_DIM],
578+
ln_f_bias: vec![0.0; EMBED_DIM],
579+
});
580+
let sim = engine.token_similarity(0, 0);
581+
assert!((sim - 1.0).abs() < 0.01, "self-similarity should be ~1.0, got {}", sim);
582+
}
583+
584+
#[test]
585+
fn test_token_similarity_different() {
586+
let engine = Gpt2Engine::new(Gpt2Weights {
587+
wte: vec![0.0; VOCAB_SIZE * EMBED_DIM],
588+
wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM],
589+
layers: Vec::new(),
590+
ln_f_weight: vec![1.0; EMBED_DIM],
591+
ln_f_bias: vec![0.0; EMBED_DIM],
592+
});
593+
let sim = engine.token_similarity(100, 50000);
594+
assert!(sim < 1.0, "different tokens should have similarity < 1.0");
595+
}
596+
597+
#[test]
598+
fn test_token_fingerprint_6bytes() {
599+
let engine = Gpt2Engine::new(Gpt2Weights {
600+
wte: vec![0.0; VOCAB_SIZE * EMBED_DIM],
601+
wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM],
602+
layers: Vec::new(),
603+
ln_f_weight: vec![1.0; EMBED_DIM],
604+
ln_f_bias: vec![0.0; EMBED_DIM],
605+
});
606+
let fp = engine.token_fingerprint(1000);
607+
assert_eq!(fp.len(), 6);
608+
// First byte is palette index
609+
let rt = &*runtime::GPT2;
610+
assert_eq!(fp[0], rt.palette.palette_index(1000));
611+
}
612+
613+
#[test]
614+
fn test_cascade_distance_levels() {
615+
let engine = Gpt2Engine::new(Gpt2Weights {
616+
wte: vec![0.0; VOCAB_SIZE * EMBED_DIM],
617+
wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM],
618+
layers: Vec::new(),
619+
ln_f_weight: vec![1.0; EMBED_DIM],
620+
ln_f_bias: vec![0.0; EMBED_DIM],
621+
});
622+
// Self-distance should resolve at HEEL level
623+
let (d, level) = engine.token_distance_cascade(0, 0);
624+
assert_eq!(d, 0);
625+
assert_eq!(level, runtime::HhtlLevel::Heel);
626+
}
627+
628+
#[test]
629+
fn test_causal_edge_emission_flag() {
630+
// Verify that emit_causal_edges flag controls edge emission
631+
let mut engine = Gpt2Engine::new(Gpt2Weights {
632+
wte: vec![0.0; VOCAB_SIZE * EMBED_DIM],
633+
wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM],
634+
layers: Vec::new(),
635+
ln_f_weight: vec![1.0; EMBED_DIM],
636+
ln_f_bias: vec![0.0; EMBED_DIM],
637+
});
638+
assert!(!engine.emit_causal_edges, "should be off by default");
639+
assert!(!engine.use_attention_table, "should be off by default");
640+
assert!(engine.causal_edges.is_empty());
641+
}
438642
}

0 commit comments

Comments
 (0)