Commit 05c9e4a

Prometheus: per-row LayerNorm + broadcast-aware tape ops + multi-token attention
The plumbing needed to make multi-token transformer training real.

Rust additions (omnimcode-core/src/interpreter.rs):

(1) tape_layernorm(x, gamma, beta, eps?): fused per-row LayerNorm.
    Forward: normalize each row to zero mean / unit variance, scale by gamma, add beta.
    Backward: full LayerNorm gradient (dx with proper centered/scaled terms, dgamma, dbeta).
    Composing this from primitives needed broadcast sub/div that weren't on the tape; the fused op is cleaner and faster.

(2) tape_row_mean(x) / tape_row_sum(x): per-row reductions [rows, cols] → [rows, 1] with element-wise backward. Building blocks for any per-row scaling.

(3) tape_add / tape_sub now support row and column vector broadcast:
      [N, C] + [1, C]  (Linear's bias add)
      [N, C] + [N, 1]  (per-row scaling)
    Forward picks the bigger shape; backward reduces the upstream gradient back to the smaller operand's shape via the new reduce_to_shape helper.

This fixed a latent bias-gradient bug in the earlier transformer demo: its 11.3x loss reduction came partly from over-broadcasting bias grads. With correct broadcast reduction, the same demo gets 4.15x (still real, just honest).

Prometheus additions (examples/lib/prometheus.omc):

(4) prom_layernorm_forward upgraded to use the fused tape_layernorm op instead of the composed (mean, sub, exp(-0.5*log(var+eps)), ...) path. Cleaner, and works on multi-token inputs.

(5) prom_embedding_batch(layer, token_ids[]): multi-token lookup via [N, vocab] one-hot @ table. Differentiable into the table.

(6) prom_cross_entropy_batch(logits, targets, vocab): mean over positions of the per-position -log(softmax), for batched LM training.

A/B demo (examples/prometheus_attention_ab.omc):

Multi-token transformer (8-token windows): seq_len=8, d_model=16, ff=32, AdamW, cross-entropy. Two arms:
  A: alpha=0   (vanilla softmax attention)
  B: alpha=0.5 (geodesic-bias attention)
3 seeds × 250 steps each. Tests whether the PyTorch geodesic win replicates in Prometheus.

The first multi-token training run worked end-to-end through the new plumbing. Single-seed result before extending to 3 seeds:

  vanilla=3.104 geodesic=3.119 delta=+0.46%

A genuine fail-forward. Could be single-seed noise, alpha not tuned, model too small, or training too short. The 3-seed run is in flight; the result will land in the next commit.

What matters infrastructure-wise: multi-token attention works in pure OMC now. The geodesic primitive is wired correctly (numerically identical to PyTorch). Whether it HELPS at this scale is an empirical question we can keep iterating on without re-shipping plumbing.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
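Review note: the broadcast rule in item (3) of the commit message can be illustrated with a small sketch. The Rust `reduce_to_shape` helper is not shown in this view, so the Python below is only an assumption about its behaviour for the two shapes the message names ([1, C] and [N, 1]); the `add_broadcast` helper is hypothetical and exists only to make the example self-contained.

```python
def reduce_to_shape(grad, rows, cols):
    """Sum an [N, C] upstream gradient down to a [rows, cols] operand shape.

    Hypothetical sketch: handles only [N, C] (no reduction),
    [1, C] (sum over rows) and [N, 1] (sum over columns).
    """
    n, c = len(grad), len(grad[0])
    if (rows, cols) == (n, c):
        return [row[:] for row in grad]
    if (rows, cols) == (1, c):
        return [[sum(grad[i][j] for i in range(n)) for j in range(c)]]
    if (rows, cols) == (n, 1):
        return [[sum(row)] for row in grad]
    raise ValueError(f"cannot reduce [{n}, {c}] to [{rows}, {cols}]")


def add_broadcast(a, b):
    """Forward add where b may be [N, C], [1, C] or [N, 1]."""
    n, c = len(a), len(a[0])
    return [[a[i][j] + b[i if len(b) == n else 0][j if len(b[0]) == c else 0]
             for j in range(c)] for i in range(n)]


# Upstream gradient of ones for a [3, 2] output.
upstream = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
bias_grad = reduce_to_shape(upstream, 1, 2)  # bias operand is [1, C]
row_grad = reduce_to_shape(upstream, 3, 1)   # per-row operand is [N, 1]
print(bias_grad)  # [[3.0, 3.0]]
print(row_grad)   # [[2.0], [2.0], [2.0]]
```

The point of the reduction is that a broadcast operand was reused once per row (or per column) in the forward pass, so its gradient is the sum of the upstream gradient over the broadcast axis, with the same shape as the operand itself.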
1 parent ea3a1e0 commit 05c9e4a

3 files changed

Lines changed: 565 additions & 32 deletions

examples/lib/prometheus.omc

Lines changed: 59 additions & 25 deletions
@@ -892,6 +892,61 @@ fn prom_embedding_params(layer) {
     return [dict_get(layer, "table")];
 }
 
+# Batched embedding lookup: token_ids[] → [N, d_model] matrix.
+# Implemented via an [N, vocab] one-hot batch then matmul with the
+# embedding table. Differentiable end-to-end.
+fn prom_embedding_batch(layer, token_ids) {
+    h vocab = dict_get(layer, "vocab");
+    h table = dict_get(layer, "table");
+    h n = arr_len(token_ids);
+    h onehot = [];
+    h i = 0;
+    while i < n {
+        h row = [];
+        h idx = arr_get(token_ids, i);
+        h j = 0;
+        while j < vocab {
+            if j == idx { arr_push(row, 1.0); }
+            else { arr_push(row, 0.0); }
+            j = j + 1;
+        }
+        arr_push(onehot, row);
+        i = i + 1;
+    }
+    h onehot_const = tape_const(onehot);
+    return tape_matmul(onehot_const, table);
+}
+
+# Batched cross-entropy: logits is [N, vocab], targets is array of N
+# integer indices. Returns scalar mean loss (averaged over positions).
+fn prom_cross_entropy_batch(logits_id, targets, vocab) {
+    h n = arr_len(targets);
+    h probs = tape_softmax(logits_id);
+    h log_probs = tape_log(probs);
+    # Build [N, vocab] mask: -1.0 at (i, targets[i]), 0 elsewhere.
+    h mask_rows = [];
+    h i = 0;
+    while i < n {
+        h row = [];
+        h tgt = arr_get(targets, i);
+        h c = 0;
+        while c < vocab {
+            if c == tgt { arr_push(row, -1.0); }
+            else { arr_push(row, 0.0); }
+            c = c + 1;
+        }
+        arr_push(mask_rows, row);
+        i = i + 1;
+    }
+    h mask = tape_const(mask_rows);
+    h selected = tape_mul(log_probs, mask);
+    # Mean over all cells = (sum of -log p_target) / (N * vocab).
+    # We want per-token mean = sum / N. Use sum + divide.
+    h s = tape_sum(selected);
+    h scale = tape_const(1.0 / n);
+    return tape_mul(s, scale);
+}
+
 # ---------------------------------------------------------------------------
 # LayerNorm - normalize each row to zero mean / unit variance, then
 # scale + shift by learned gamma/beta.
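Review note: the two helpers added in the hunk above can be cross-checked against a plain reference. The Python below is a sketch under stated assumptions, not the OMC implementation: `embedding_batch` spells out the one-hot matmul by hand, and `cross_entropy_batch` computes the same per-position mean but swaps the `log(softmax(...))` path for a log-sum-exp form, which avoids taking the log of a very small probability. The function names and the sample table are illustrative only.

```python
import math


def embedding_batch(table, token_ids):
    """[N, vocab] one-hot @ [vocab, d_model] table, written out by hand."""
    vocab, d_model = len(table), len(table[0])
    onehot = [[1.0 if j == t else 0.0 for j in range(vocab)] for t in token_ids]
    return [[sum(onehot[i][v] * table[v][d] for v in range(vocab))
             for d in range(d_model)] for i in range(len(token_ids))]


def cross_entropy_batch(logits, targets):
    """Mean over positions of -log softmax(logits[i])[targets[i]]."""
    total = 0.0
    for row, tgt in zip(logits, targets):
        m = max(row)  # subtract the row max before exp for stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[tgt]
    return total / len(targets)


table = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # vocab=3, d_model=2
rows = embedding_batch(table, [2, 0])          # selects table rows 2 and 0
loss = cross_entropy_batch([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [1, 2])
print(rows)  # [[0.5, 0.6], [0.1, 0.2]]
print(loss)  # uniform logits over 3 classes give log(3) per position
```

The one-hot matmul costs O(N * vocab * d_model) compared with a direct row gather, but it reuses the existing matmul backward, so gradients reach the table without a dedicated gather op.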
@@ -923,35 +978,14 @@ fn prom_layernorm_new(d_model, rng_state) {
     return layer;
 }
 
-# Forward: x is [1, d_model] (single row); subtract mean, divide by
-# stable std, scale + shift. The Mean op already gives us per-tensor
-# mean; for per-row mean we use the same op since our inputs here are
-# single-row.
+# Forward: x is [N, d_model]; per-row layer norm via the fused
+# tape_layernorm Rust op. Works for both single-row [1, d] and
+# multi-token [seq, d] shapes - same code path.
 fn prom_layernorm_forward(layer, x_id) {
     h gamma = dict_get(layer, "gamma");
     h beta = dict_get(layer, "beta");
     h eps = dict_get(layer, "eps");
-
-    h mean_id = tape_mean(x_id);
-    # Broadcast mean as a const shaped like x; OMC's tape mul handles
-    # scalar broadcast.
-    h centered = tape_sub(x_id, mean_id);
-    h sq = tape_mul(centered, centered);
-    h variance = tape_mean(sq);
-    h std_const = tape_const(eps);
-    h denom_sq = tape_add(variance, std_const);
-    # We need sqrt(variance); use tape_pow_int(denom_sq, ...) - but
-    # pow_int can only do integer powers. Approximate sqrt via the
-    # identity sqrt(x) = x^0.5: not directly available; use exp(0.5*log(x)).
-    h log_v = tape_log(denom_sq);
-    h half = tape_const(0.5);
-    h half_log = tape_mul(log_v, half);
-    h std_inv_log = tape_neg(half_log);
-    h std_inv = tape_exp(std_inv_log);  # = 1 / sqrt(variance + eps)
-
-    h normed = tape_mul(centered, std_inv);
-    h scaled = tape_mul(normed, gamma);
-    return tape_add(scaled, beta);
+    return tape_layernorm(x_id, gamma, beta, eps);
 }
 
 fn prom_layernorm_params(layer) {
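Review note: the fused `tape_layernorm` op lives in the Rust interpreter and is not visible in this diff. As a reference for what a per-row LayerNorm forward and backward usually compute, here is a minimal Python sketch using the standard textbook formulas. It is an assumption that the Rust op matches these formulas; the function names, the default `eps`, and the sample values are illustrative.

```python
import math


def layernorm_forward(x, gamma, beta, eps=1e-5):
    """Normalize each row of x to zero mean / unit variance, then scale + shift."""
    out, cache = [], []
    for row in x:
        c = len(row)
        mu = sum(row) / c
        var = sum((v - mu) ** 2 for v in row) / c
        inv_std = 1.0 / math.sqrt(var + eps)
        xhat = [(v - mu) * inv_std for v in row]
        out.append([g * h + b for g, h, b in zip(gamma, xhat, beta)])
        cache.append((xhat, inv_std))
    return out, cache


def layernorm_backward(dy, gamma, cache):
    """Return (dx, dgamma, dbeta) for upstream gradient dy of shape [N, C]."""
    c = len(gamma)
    dgamma, dbeta, dx = [0.0] * c, [0.0] * c, []
    for dy_row, (xhat, inv_std) in zip(dy, cache):
        for j in range(c):
            dgamma[j] += dy_row[j] * xhat[j]  # summed over rows
            dbeta[j] += dy_row[j]             # summed over rows
        dxhat = [d * g for d, g in zip(dy_row, gamma)]
        s1 = sum(dxhat)
        s2 = sum(d * h for d, h in zip(dxhat, xhat))
        dx.append([inv_std / c * (c * dxhat[j] - s1 - xhat[j] * s2)
                   for j in range(c)])
    return dx, dgamma, dbeta


x = [[1.0, 2.0, 4.0], [0.5, -1.0, 3.0]]
gamma, beta = [1.5, 0.5, 2.0], [0.1, 0.2, 0.3]
y, cache = layernorm_forward(x, gamma, beta)
dy = [[1.0, -2.0, 0.5], [0.3, 0.7, -1.0]]  # arbitrary upstream gradient
dx, dgamma, dbeta = layernorm_backward(dy, gamma, cache)


def loss_at(x_val):
    """Scalar loss sum(dy * y) so that dloss/dy equals dy."""
    out, _ = layernorm_forward(x_val, gamma, beta)
    return sum(d * o for dr, orow in zip(dy, out) for d, o in zip(dr, orow))


# Central finite-difference check of dx[0][1].
step = 1e-6
x_plus = [row[:] for row in x]
x_plus[0][1] += step
x_minus = [row[:] for row in x]
x_minus[0][1] -= step
numeric = (loss_at(x_plus) - loss_at(x_minus)) / (2 * step)
print(abs(numeric - dx[0][1]) < 1e-5)
```

Note that dgamma and dbeta are reduced over rows, which is the same broadcast-reduction rule the commit message describes for bias gradients.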
examples/prometheus_attention_ab.omc

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
+# Multi-token transformer with geodesic attention A/B.
+#
+# The key experiment: does the geodesic attention bias that won 3/3
+# seeds in today's PyTorch experiment also help when training a
+# Prometheus model from scratch on real text?
+#
+# Architecture (8-token sliding window):
+#   tokens[8]
+#     ↓ Embedding → [8, d_model]
+#     ↓ + CRT-PE
+#   x
+#     ↓ Attention (with OR without geodesic bias on positions)
+#     ↓ + residual
+#     ↓ LayerNorm
+#     ↓ FFN
+#     ↓ + residual
+#     ↓ LayerNorm
+#     ↓ head → [8, vocab]
+#   logits, target the next token at every position.
+#
+#   A: alpha=0 → vanilla softmax attention (geodesic OFF)
+#   B: alpha=0.5 → geodesic-biased attention (geodesic ON)
+#
+# Stop condition: report final tail-mean loss for both arms.
+# If B < A, the geodesic win replicates in Prometheus too.
+
+import "examples/lib/prometheus.omc";
+
+fn build_vocab(text) {
+    h seen = dict_new();
+    h chars = [];
+    h i = 0;
+    while i < str_len(text) {
+        h ch = str_slice(text, i, i + 1);
+        if !dict_has(seen, ch) {
+            dict_set(seen, ch, arr_len(chars));
+            arr_push(chars, ch);
+        }
+        i = i + 1;
+    }
+    h v = dict_new();
+    dict_set(v, "chars", chars);
+    dict_set(v, "lookup", seen);
+    return v;
+}
+
+fn encode(text, vocab) {
+    h lookup = dict_get(vocab, "lookup");
+    h ids = [];
+    h i = 0;
+    while i < str_len(text) {
+        h ch = str_slice(text, i, i + 1);
+        arr_push(ids, dict_get(lookup, ch));
+        i = i + 1;
+    }
+    return ids;
+}
+
fn build_model(vocab_size, d_model, ff_dim, seq_len, alpha, seed) {
60+
h emb = prom_embedding_new(vocab_size, d_model, seed);
61+
h s1 = dict_get(emb, "rng_state");
62+
h attn = prom_attention_new(d_model, seq_len, s1 + 11);
63+
dict_set(attn, "alpha", alpha); # geodesic strength
64+
h s2 = dict_get(attn, "rng_state");
65+
h ln1 = prom_layernorm_new(d_model, s2);
66+
h ff_up = prom_linear_new(d_model, ff_dim, s2 + 13);
67+
h s3 = dict_get(ff_up, "rng_state");
68+
h ff_down = prom_linear_new(ff_dim, d_model, s3);
69+
h s4 = dict_get(ff_down, "rng_state");
70+
h ln2 = prom_layernorm_new(d_model, s4);
71+
h head = prom_linear_new(d_model, vocab_size, s4 + 17);
72+
73+
h m = dict_new();
74+
dict_set(m, "emb", emb);
75+
dict_set(m, "attn", attn);
76+
dict_set(m, "ln1", ln1);
77+
dict_set(m, "ff_up", ff_up);
78+
dict_set(m, "ff_down", ff_down);
79+
dict_set(m, "ln2", ln2);
80+
dict_set(m, "head", head);
81+
dict_set(m, "alpha", alpha);
82+
return m;
83+
}
84+
85+
# Forward over an 8-token window. Returns [8, vocab] logits.
86+
fn forward_window(model, token_ids, pe_table) {
87+
h x = prom_embedding_batch(dict_get(model, "emb"), token_ids);
88+
89+
# Add CRT-PE rows for these positions (0..N).
90+
h pe_rows = [];
91+
h i = 0;
92+
while i < arr_len(token_ids) {
93+
arr_push(pe_rows, arr_get(pe_table, i));
94+
i = i + 1;
95+
}
96+
h pe_const = tape_const(pe_rows);
97+
x = tape_add(x, pe_const);
98+
99+
# Attention + residual.
100+
h attn_out = prom_attention_forward(dict_get(model, "attn"), x);
101+
h x_post_attn = tape_add(x, attn_out);
102+
103+
# LayerNorm.
104+
h normed1 = prom_layernorm_forward(dict_get(model, "ln1"), x_post_attn);
105+
106+
# FFN.
107+
h up = prom_linear_forward(dict_get(model, "ff_up"), normed1);
108+
h activated = prom_relu(up);
109+
h down = prom_linear_forward(dict_get(model, "ff_down"), activated);
110+
h x_post_ff = tape_add(x_post_attn, down);
111+
112+
# LayerNorm + head.
113+
h normed2 = prom_layernorm_forward(dict_get(model, "ln2"), x_post_ff);
114+
return prom_linear_forward(dict_get(model, "head"), normed2);
115+
}
116+
117+
fn collect_all_params(model) {
118+
h layers = [
119+
dict_get(model, "emb"),
120+
dict_get(model, "attn"),
121+
dict_get(model, "ln1"),
122+
dict_get(model, "ff_up"),
123+
dict_get(model, "ff_down"),
124+
dict_get(model, "ln2"),
125+
dict_get(model, "head"),
126+
];
127+
return prom_collect_params_v2(layers);
128+
}
129+
+fn train_arm(alpha, text, vocab, vocab_size, ids, seq_len, d_model,
+             ff_dim, n_windows, lr, steps, seed) {
+    tape_reset();
+    h model = build_model(vocab_size, d_model, ff_dim, seq_len, alpha, seed);
+    h params = collect_all_params(model);
+    h opt = prom_adamw_new(params, lr, 0.9, 0.999, 1e-8, 0.0);
+    h pe_table = prom_crt_pe_matrix(seq_len, d_model);
+
+    h tail_losses = [];
+    h step = 0;
+    while step < steps {
+        # Deterministic cycling window start: step mod n_windows (assumes integer `/`).
+        h start = step - (step / n_windows) * n_windows;
+        h window = [];
+        h targets = [];
+        h k = 0;
+        while k < seq_len {
+            arr_push(window, arr_get(ids, start + k));
+            arr_push(targets, arr_get(ids, start + k + 1));
+            k = k + 1;
+        }
+        h logits = forward_window(model, window, pe_table);
+        h loss = prom_cross_entropy_batch(logits, targets, vocab_size);
+        tape_backward(loss);
+        prom_adamw_step(opt);
+        if step >= steps - 10 { arr_push(tail_losses, tape_value(loss)); }
+        step = step + 1;
+    }
+    h sum = 0.0;
+    h i = 0;
+    while i < arr_len(tail_losses) { sum = sum + arr_get(tail_losses, i); i = i + 1; }
+    return sum / arr_len(tail_losses);
+}
+
+fn main() {
+    print("=== Prometheus multi-token attention: geodesic A/B ===");
+    h text = "the quick brown fox jumps over the lazy dog and the dog sleeps in the sun";
+    h vocab = build_vocab(text);
+    h vocab_size = arr_len(dict_get(vocab, "chars"));
+    h ids = encode(text, vocab);
+    h seq_len = 8;
+    h d_model = 16;
+    h ff_dim = 32;
+    h n_windows = arr_len(ids) - seq_len - 1;
+    h lr = 0.02;
+    h steps = 250;
+    h seeds = [42, 7, 123];
+
+    print(concat_many("corpus length: ", to_string(str_len(text))));
+    print(concat_many("vocab: ", to_string(vocab_size)));
+    print(concat_many("seq_len: ", to_string(seq_len), " windows: ", to_string(n_windows)));
+    print(concat_many("d_model: ", to_string(d_model), " ff: ", to_string(ff_dim)));
+    print(concat_many("steps: ", to_string(steps), " lr: ", to_string(lr), " seeds: ", to_string(seeds)));
+    print("");
+
+    h a_results = [];
+    h b_results = [];
+    h s = 0;
+    while s < arr_len(seeds) {
+        h seed = arr_get(seeds, s);
+        h loss_a = train_arm(0.0, text, vocab, vocab_size, ids, seq_len,
+                             d_model, ff_dim, n_windows, lr, steps, seed);
+        h loss_b = train_arm(0.5, text, vocab, vocab_size, ids, seq_len,
+                             d_model, ff_dim, n_windows, lr, steps, seed);
+        arr_push(a_results, loss_a);
+        arr_push(b_results, loss_b);
+        h delta = loss_b - loss_a;
+        h tag = "(geodesic worse)";
+        if loss_b < loss_a { tag = "(geodesic better)"; }
+        print(concat_many("seed ", to_string(seed),
+                          " vanilla=", to_string(loss_a),
+                          " geodesic=", to_string(loss_b),
+                          " delta=", to_string(delta), " ", tag));
+        s = s + 1;
+    }
+
+    h a_sum = 0.0;
+    h b_sum = 0.0;
+    h wins = 0;
+    h i = 0;
+    while i < arr_len(seeds) {
+        a_sum = a_sum + arr_get(a_results, i);
+        b_sum = b_sum + arr_get(b_results, i);
+        if arr_get(b_results, i) < arr_get(a_results, i) { wins = wins + 1; }
+        i = i + 1;
+    }
+    h a_mean = a_sum / arr_len(seeds);
+    h b_mean = b_sum / arr_len(seeds);
+    h rel = (b_mean - a_mean) / a_mean * 100.0;
+
+    print("");
+    print("=== Multi-seed verdict ===");
+    print(concat_many(" vanilla mean: ", to_string(a_mean)));
+    print(concat_many(" geodesic mean: ", to_string(b_mean)));
+    print(concat_many(" geodesic vs vanilla: ", to_string(rel), "%"));
+    print(concat_many(" geodesic wins: ", to_string(wins), "/", to_string(arr_len(seeds))));
+    print("");
+    if wins >= 2 {
+        print("[WIN] Geodesic attention helps Prometheus on majority of seeds.");
+        print(" Cross-platform substrate-positional-bias validation.");
+        print(" 🥂");
+    } elif wins == 0 {
+        print("[FAIL-FORWARD] Geodesic won 0/3 seeds in Prometheus.");
+        print(" Honest negative: the PyTorch -0.4% win at distractor=0.20");
+        print(" didn't replicate at this scale (single-block model, no");
+        print(" distractor mix, 250 steps). Suggests either: PyTorch result");
+        print(" was scale-specific, OR our Prometheus model needs the");
+        print(" same training setup (many more steps, mix of clean+noise).");
+    } else {
+        print("[INCONCLUSIVE] 1/3: noise. Need more seeds or larger model.");
+    }
+}
+
+main();
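
Review note: the window selection inside `train_arm` can be sanity-checked in isolation. The Python below is a hedged sketch of the same indexing, assuming the OMC expression `step - (step / n_windows) * n_windows` behaves as an integer modulo; `make_window` and the stand-in `ids` list are illustrative names, not part of the demo.

```python
def make_window(ids, step, seq_len):
    """Return (window, targets) for this step, cycling over window starts."""
    n_windows = len(ids) - seq_len - 1  # same count the demo computes
    start = step % n_windows
    window = ids[start:start + seq_len]
    targets = ids[start + 1:start + seq_len + 1]  # next token at every position
    return window, targets


ids = list(range(12))  # stand-in token ids, not the demo corpus
w0, t0 = make_window(ids, 0, 8)
w4, t4 = make_window(ids, 4, 8)  # n_windows is 3 here, so step 4 wraps to start 1
print(w0, t0)
print(w4, t4)
```

With this count the largest index read is `len(ids) - 2`, so the slice never runs past the end of the corpus; the final valid start position is simply never visited.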
