|
| 1 | +# L0 vs L1 head-to-head in OMC — cross-runtime validation. |
| 2 | +# |
| 3 | +# PyTorch shows L1 (substrate-K) beats L0 (standard QKV) at every |
| 4 | +# scale tested. Verify the same ranking holds in OMC's tape-based |
| 5 | +# autograd. |
| 6 | +# |
| 7 | +# Setup: |
| 8 | +# - 200-char English passage (larger than the bigram-cycle but |
| 9 | +# small enough that pure-OMC training finishes in a few minutes) |
| 10 | +# - Single-block transformer (where L1's advantage is largest) |
| 11 | +# - 3 seeds, 300 steps each |
| 12 | +# - AdamW lr=0.01, d_model=16, ff=32 |
| 13 | +# |
| 14 | +# Stop condition: L1 beats L0 on at least 2/3 seeds. If yes, |
| 15 | +# the substrate-K finding is cross-runtime (OMC + PyTorch both |
| 16 | +# agree). If no, OMC has some runtime-specific behavior to debug. |
| 17 | + |
| 18 | +import "examples/lib/prometheus.omc"; |
| 19 | + |
| 20 | +fn build_vocab(text) { |
| 21 | + h seen = dict_new(); |
| 22 | + h chars = []; |
| 23 | + h i = 0; |
| 24 | + while i < str_len(text) { |
| 25 | + h ch = str_slice(text, i, i + 1); |
| 26 | + if !dict_has(seen, ch) { |
| 27 | + dict_set(seen, ch, arr_len(chars)); |
| 28 | + arr_push(chars, ch); |
| 29 | + } |
| 30 | + i = i + 1; |
| 31 | + } |
| 32 | + h v = dict_new(); |
| 33 | + dict_set(v, "chars", chars); |
| 34 | + dict_set(v, "lookup", seen); |
| 35 | + return v; |
| 36 | +} |
| 37 | + |
| 38 | +fn encode(text, vocab) { |
| 39 | + h lookup = dict_get(vocab, "lookup"); |
| 40 | + h ids = []; |
| 41 | + h i = 0; |
| 42 | + while i < str_len(text) { |
| 43 | + h ch = str_slice(text, i, i + 1); |
| 44 | + arr_push(ids, dict_get(lookup, ch)); |
| 45 | + i = i + 1; |
| 46 | + } |
| 47 | + return ids; |
| 48 | +} |
| 49 | + |
| 50 | +fn build_model(variant, vocab_size, d_model, ff_dim, seq_len, seed) { |
| 51 | + h emb = prom_embedding_new(vocab_size, d_model, seed); |
| 52 | + h s1 = dict_get(emb, "rng_state"); |
| 53 | + h attn = null; |
| 54 | + h s2 = s1 + 11; |
| 55 | + if variant == "L0" { |
| 56 | + attn = prom_attention_new(d_model, seq_len, s2); |
| 57 | + dict_set(attn, "alpha", 0.0); |
| 58 | + s2 = dict_get(attn, "rng_state"); |
| 59 | + } elif variant == "L1" { |
| 60 | + attn = prom_attention_substrate_k_new(d_model, seq_len, s2); |
| 61 | + s2 = dict_get(attn, "rng_state"); |
| 62 | + } |
| 63 | + h ln1 = prom_layernorm_new(d_model, s2); |
| 64 | + h ff_up = prom_linear_new(d_model, ff_dim, s2 + 13); |
| 65 | + h s3 = dict_get(ff_up, "rng_state"); |
| 66 | + h ff_down = prom_linear_new(ff_dim, d_model, s3); |
| 67 | + h s4 = dict_get(ff_down, "rng_state"); |
| 68 | + h ln2 = prom_layernorm_new(d_model, s4); |
| 69 | + h head = prom_linear_new(d_model, vocab_size, s4 + 17); |
| 70 | + |
| 71 | + h m = dict_new(); |
| 72 | + dict_set(m, "variant", variant); |
| 73 | + dict_set(m, "emb", emb); |
| 74 | + dict_set(m, "attn", attn); |
| 75 | + dict_set(m, "ln1", ln1); |
| 76 | + dict_set(m, "ff_up", ff_up); |
| 77 | + dict_set(m, "ff_down", ff_down); |
| 78 | + dict_set(m, "ln2", ln2); |
| 79 | + dict_set(m, "head", head); |
| 80 | + return m; |
| 81 | +} |
| 82 | + |
| 83 | +fn attn_forward(variant, attn, x_id) { |
| 84 | + if variant == "L0" { return prom_attention_forward(attn, x_id); } |
| 85 | + return prom_attention_substrate_k_forward(attn, x_id); |
| 86 | +} |
| 87 | + |
| 88 | +fn attn_params(variant, attn) { |
| 89 | + if variant == "L0" { return prom_attention_params(attn); } |
| 90 | + return prom_attention_substrate_k_params(attn); |
| 91 | +} |
| 92 | + |
| 93 | +fn forward_window(model, token_ids, pe_table) { |
| 94 | + h variant = dict_get(model, "variant"); |
| 95 | + h x = prom_embedding_batch(dict_get(model, "emb"), token_ids); |
| 96 | + |
| 97 | + h pe_rows = []; |
| 98 | + h i = 0; |
| 99 | + while i < arr_len(token_ids) { |
| 100 | + arr_push(pe_rows, arr_get(pe_table, i)); |
| 101 | + i = i + 1; |
| 102 | + } |
| 103 | + h pe_const = tape_const(pe_rows); |
| 104 | + x = tape_add(x, pe_const); |
| 105 | + |
| 106 | + h attn_out = attn_forward(variant, dict_get(model, "attn"), x); |
| 107 | + h x_post_attn = tape_add(x, attn_out); |
| 108 | + h normed1 = prom_layernorm_forward(dict_get(model, "ln1"), x_post_attn); |
| 109 | + h up = prom_linear_forward(dict_get(model, "ff_up"), normed1); |
| 110 | + h activated = prom_relu(up); |
| 111 | + h down = prom_linear_forward(dict_get(model, "ff_down"), activated); |
| 112 | + h x_post_ff = tape_add(x_post_attn, down); |
| 113 | + h normed2 = prom_layernorm_forward(dict_get(model, "ln2"), x_post_ff); |
| 114 | + return prom_linear_forward(dict_get(model, "head"), normed2); |
| 115 | +} |
| 116 | + |
| 117 | +fn collect_all_params(model) { |
| 118 | + h variant = dict_get(model, "variant"); |
| 119 | + h attn_p = attn_params(variant, dict_get(model, "attn")); |
| 120 | + h other = prom_collect_params_v2([ |
| 121 | + dict_get(model, "emb"), |
| 122 | + dict_get(model, "ln1"), |
| 123 | + dict_get(model, "ff_up"), |
| 124 | + dict_get(model, "ff_down"), |
| 125 | + dict_get(model, "ln2"), |
| 126 | + dict_get(model, "head"), |
| 127 | + ]); |
| 128 | + h out = []; |
| 129 | + h i = 0; |
| 130 | + while i < arr_len(attn_p) { arr_push(out, arr_get(attn_p, i)); i = i + 1; } |
| 131 | + i = 0; |
| 132 | + while i < arr_len(other) { arr_push(out, arr_get(other, i)); i = i + 1; } |
| 133 | + return out; |
| 134 | +} |
| 135 | + |
| 136 | +fn train_arm(variant, vocab_size, ids, seq_len, d_model, ff_dim, lr, steps, seed) { |
| 137 | + tape_reset(); |
| 138 | + h model = build_model(variant, vocab_size, d_model, ff_dim, seq_len, seed); |
| 139 | + h params = collect_all_params(model); |
| 140 | + h opt = prom_adamw_new(params, lr, 0.9, 0.999, 1e-8, 0.0); |
| 141 | + h pe_table = prom_crt_pe_matrix(seq_len, d_model); |
| 142 | + h n_windows = arr_len(ids) - seq_len - 1; |
| 143 | + |
| 144 | + h tail_losses = []; |
| 145 | + h step = 0; |
| 146 | + while step < steps { |
| 147 | + h start = step - (step / n_windows) * n_windows; |
| 148 | + h window = []; |
| 149 | + h targets = []; |
| 150 | + h k = 0; |
| 151 | + while k < seq_len { |
| 152 | + arr_push(window, arr_get(ids, start + k)); |
| 153 | + arr_push(targets, arr_get(ids, start + k + 1)); |
| 154 | + k = k + 1; |
| 155 | + } |
| 156 | + h logits = forward_window(model, window, pe_table); |
| 157 | + h loss = prom_cross_entropy_batch(logits, targets, vocab_size); |
| 158 | + tape_backward(loss); |
| 159 | + prom_adamw_step(opt); |
| 160 | + if step >= steps - 10 { arr_push(tail_losses, tape_value(loss)); } |
| 161 | + step = step + 1; |
| 162 | + } |
| 163 | + h sum = 0.0; |
| 164 | + h i = 0; |
| 165 | + while i < arr_len(tail_losses) { sum = sum + arr_get(tail_losses, i); i = i + 1; } |
| 166 | + h result = dict_new(); |
| 167 | + dict_set(result, "loss", sum / arr_len(tail_losses)); |
| 168 | + dict_set(result, "n_params", arr_len(params)); |
| 169 | + return result; |
| 170 | +} |
| 171 | + |
| 172 | +fn main() { |
| 173 | + print("=== OMC L0-vs-L1 cross-runtime validation ==="); |
| 174 | + # ~200-char English passage with real positional structure. |
| 175 | + h text = "the rain in spain falls mainly on the plain and the sun rises in the east while the moon hides behind the mountain peaks of distant lands where ancient creatures sleep in caves of silver"; |
| 176 | + print(concat_many("corpus: ", to_string(str_len(text)), " chars")); |
| 177 | + |
| 178 | + h vocab = build_vocab(text); |
| 179 | + h vocab_size = arr_len(dict_get(vocab, "chars")); |
| 180 | + h ids = encode(text, vocab); |
| 181 | + h seq_len = 8; |
| 182 | + h d_model = 16; |
| 183 | + h ff_dim = 32; |
| 184 | + h lr = 0.01; |
| 185 | + h steps = 300; |
| 186 | + h seeds = [42, 7, 123]; |
| 187 | + |
| 188 | + print(concat_many("vocab: ", to_string(vocab_size), |
| 189 | + " seq_len: ", to_string(seq_len), |
| 190 | + " d_model: ", to_string(d_model), |
| 191 | + " ff: ", to_string(ff_dim))); |
| 192 | + print(concat_many("steps: ", to_string(steps), " lr: ", to_string(lr))); |
| 193 | + print(""); |
| 194 | + |
| 195 | + h l0_losses = []; |
| 196 | + h l1_losses = []; |
| 197 | + h l0_params = 0; |
| 198 | + h l1_params = 0; |
| 199 | + h s = 0; |
| 200 | + while s < arr_len(seeds) { |
| 201 | + h seed = arr_get(seeds, s); |
| 202 | + h r0 = train_arm("L0", vocab_size, ids, seq_len, d_model, ff_dim, lr, steps, seed); |
| 203 | + h r1 = train_arm("L1", vocab_size, ids, seq_len, d_model, ff_dim, lr, steps, seed); |
| 204 | + arr_push(l0_losses, dict_get(r0, "loss")); |
| 205 | + arr_push(l1_losses, dict_get(r1, "loss")); |
| 206 | + l0_params = dict_get(r0, "n_params"); |
| 207 | + l1_params = dict_get(r1, "n_params"); |
| 208 | + h delta = dict_get(r1, "loss") - dict_get(r0, "loss"); |
| 209 | + h tag = "L0 better"; |
| 210 | + if dict_get(r1, "loss") < dict_get(r0, "loss") { tag = "L1 better"; } |
| 211 | + print(concat_many("seed ", to_string(seed), |
| 212 | + " L0=", to_string(dict_get(r0, "loss")), |
| 213 | + " L1=", to_string(dict_get(r1, "loss")), |
| 214 | + " delta=", to_string(delta), " ", tag)); |
| 215 | + s = s + 1; |
| 216 | + } |
| 217 | + print(""); |
| 218 | + |
| 219 | + h l0_sum = 0.0; |
| 220 | + h l1_sum = 0.0; |
| 221 | + h wins = 0; |
| 222 | + h i = 0; |
| 223 | + while i < arr_len(seeds) { |
| 224 | + l0_sum = l0_sum + arr_get(l0_losses, i); |
| 225 | + l1_sum = l1_sum + arr_get(l1_losses, i); |
| 226 | + if arr_get(l1_losses, i) < arr_get(l0_losses, i) { wins = wins + 1; } |
| 227 | + i = i + 1; |
| 228 | + } |
| 229 | + h l0_mean = l0_sum / arr_len(seeds); |
| 230 | + h l1_mean = l1_sum / arr_len(seeds); |
| 231 | + h rel = (l1_mean - l0_mean) / l0_mean * 100.0; |
| 232 | + |
| 233 | + print("=== Cross-runtime verdict ==="); |
| 234 | + print(concat_many("L0 params: ", to_string(l0_params), " L1 params: ", to_string(l1_params))); |
| 235 | + print(concat_many("L0 mean: ", to_string(l0_mean))); |
| 236 | + print(concat_many("L1 mean: ", to_string(l1_mean))); |
| 237 | + print(concat_many("L1 vs L0: ", to_string(rel), "% wins: ", to_string(wins), "/", to_string(arr_len(seeds)))); |
| 238 | + print(""); |
| 239 | + if wins >= 2 { |
| 240 | + print("[CROSS-RUNTIME WIN] OMC tape produces the same L1-beats-L0 result"); |
| 241 | + print(" as PyTorch. The substrate-K finding holds across:"); |
| 242 | + print(" - OMC tape autograd"); |
| 243 | + print(" - PyTorch torch.autograd"); |
| 244 | + print(" Same architecture, same direction. Real result."); |
| 245 | + } else { |
| 246 | + print("[CROSS-RUNTIME MISMATCH] OMC didn't replicate L1's advantage."); |
| 247 | + print(" Investigate OMC-specific behavior:"); |
| 248 | + print(" - tape arithmetic precision"); |
| 249 | + print(" - AdamW state representation"); |
| 250 | + print(" - prom_attention_substrate_k_forward correctness"); |
| 251 | + } |
| 252 | +} |
| 253 | + |
| 254 | +main(); |
0 commit comments