|
| 1 | +# MH + Q6 compound test (#3 — validates v0.8.8 finding in multi-head setting). |
| 2 | +# |
| 3 | +# v0.8.5 saw MH at d_model=32 win -0.25% vs SH (single-head). v0.8.8 saw |
| 4 | +# Q6 push attention 8.31x toward substrate positions after training. If |
| 5 | +# Q6 sculpts attention per-head, the MH+Q6 combo should beat plain MH by |
| 6 | +# more than the SH+Q6 combo beat plain SH. |
| 7 | +# |
| 8 | +# Four arms at d_model=64, n_heads=4, 3 seeds, 400 steps: |
| 9 | +# A. MH off (substrate-K + S-MOD + V, Q6 off) |
| 10 | +# B. MH+Q6 fused (same + Q6 fused) |
| 11 | +# C. SH off (single-head, Q6 off) — reference |
| 12 | +# D. SH+Q6 fused — reference |
| 13 | + |
| 14 | +import "examples/lib/prometheus.omc"; |
| 15 | + |
| 16 | +fn build_vocab(text) { |
| 17 | + h seen = dict_new(); |
| 18 | + h chars = []; |
| 19 | + h i = 0; |
| 20 | + while i < str_len(text) { |
| 21 | + h ch = str_slice(text, i, i + 1); |
| 22 | + if !dict_has(seen, ch) { dict_set(seen, ch, arr_len(chars)); arr_push(chars, ch); } |
| 23 | + i = i + 1; |
| 24 | + } |
| 25 | + h v = dict_new(); |
| 26 | + dict_set(v, "chars", chars); |
| 27 | + dict_set(v, "lookup", seen); |
| 28 | + return v; |
| 29 | +} |
| 30 | + |
| 31 | +fn encode(text, vocab) { |
| 32 | + h lookup = dict_get(vocab, "lookup"); |
| 33 | + h ids = []; |
| 34 | + h i = 0; |
| 35 | + while i < str_len(text) { |
| 36 | + h ch = str_slice(text, i, i + 1); |
| 37 | + arr_push(ids, dict_get(lookup, ch)); |
| 38 | + i = i + 1; |
| 39 | + } |
| 40 | + return ids; |
| 41 | +} |
| 42 | + |
| 43 | +fn build_model(arm, vocab_size, d_model, ff_dim, seq_len, n_heads, seed) { |
| 44 | + h emb = prom_embedding_new(vocab_size, d_model, seed); |
| 45 | + h s1 = dict_get(emb, "rng_state"); |
| 46 | + h attn = null; |
| 47 | + h s2 = s1 + 11; |
| 48 | + if arm == "SH" { |
| 49 | + attn = prom_attention_substrate_k_new(d_model, seq_len, s2); |
| 50 | + s2 = dict_get(attn, "rng_state"); |
| 51 | + } elif arm == "SHQ6" { |
| 52 | + attn = prom_attention_substrate_k_new(d_model, seq_len, s2); |
| 53 | + dict_set(attn, "q6_mode", "fused"); |
| 54 | + s2 = dict_get(attn, "rng_state"); |
| 55 | + } elif arm == "MH" { |
| 56 | + attn = prom_attention_substrate_k_mh_new(d_model, seq_len, n_heads, s2); |
| 57 | + s2 = dict_get(attn, "rng_state"); |
| 58 | + } else { # MHQ6 |
| 59 | + attn = prom_attention_substrate_k_mh_new(d_model, seq_len, n_heads, s2); |
| 60 | + dict_set(attn, "q6_mode", "fused"); |
| 61 | + s2 = dict_get(attn, "rng_state"); |
| 62 | + } |
| 63 | + h ln1 = prom_layernorm_new(d_model, s2); |
| 64 | + h ff_up = prom_linear_new(d_model, ff_dim, s2 + 13); |
| 65 | + h s3 = dict_get(ff_up, "rng_state"); |
| 66 | + h ff_down = prom_linear_new(ff_dim, d_model, s3); |
| 67 | + h s4 = dict_get(ff_down, "rng_state"); |
| 68 | + h ln2 = prom_layernorm_new(d_model, s4); |
| 69 | + h head = prom_linear_new(d_model, vocab_size, s4 + 17); |
| 70 | + h m = dict_new(); |
| 71 | + dict_set(m, "arm", arm); |
| 72 | + dict_set(m, "emb", emb); |
| 73 | + dict_set(m, "attn", attn); |
| 74 | + dict_set(m, "ln1", ln1); |
| 75 | + dict_set(m, "ff_up", ff_up); |
| 76 | + dict_set(m, "ff_down", ff_down); |
| 77 | + dict_set(m, "ln2", ln2); |
| 78 | + dict_set(m, "head", head); |
| 79 | + return m; |
| 80 | +} |
| 81 | + |
| 82 | +fn attn_forward(arm, attn, x_id) { |
| 83 | + if arm == "MH" { return prom_attention_substrate_k_mh_forward(attn, x_id); } |
| 84 | + if arm == "MHQ6" { return prom_attention_substrate_k_mh_forward(attn, x_id); } |
| 85 | + return prom_attention_substrate_k_forward(attn, x_id); |
| 86 | +} |
| 87 | + |
| 88 | +fn attn_params(arm, attn) { |
| 89 | + if arm == "MH" { return prom_attention_substrate_k_mh_params(attn); } |
| 90 | + if arm == "MHQ6" { return prom_attention_substrate_k_mh_params(attn); } |
| 91 | + return prom_attention_substrate_k_params(attn); |
| 92 | +} |
| 93 | + |
| 94 | +fn forward_window(model, token_ids, pe_table) { |
| 95 | + h arm = dict_get(model, "arm"); |
| 96 | + h x = prom_embedding_batch(dict_get(model, "emb"), token_ids); |
| 97 | + h pe_rows = []; |
| 98 | + h i = 0; |
| 99 | + while i < arr_len(token_ids) { arr_push(pe_rows, arr_get(pe_table, i)); i = i + 1; } |
| 100 | + x = tape_add(x, tape_const(pe_rows)); |
| 101 | + h attn_out = attn_forward(arm, dict_get(model, "attn"), x); |
| 102 | + h x_post = tape_add(x, attn_out); |
| 103 | + h n1 = prom_layernorm_forward(dict_get(model, "ln1"), x_post); |
| 104 | + h up = prom_linear_forward(dict_get(model, "ff_up"), n1); |
| 105 | + h down = prom_linear_forward(dict_get(model, "ff_down"), prom_relu(up)); |
| 106 | + h x_ff = tape_add(x_post, down); |
| 107 | + h n2 = prom_layernorm_forward(dict_get(model, "ln2"), x_ff); |
| 108 | + return prom_linear_forward(dict_get(model, "head"), n2); |
| 109 | +} |
| 110 | + |
| 111 | +fn collect_all(model) { |
| 112 | + h arm = dict_get(model, "arm"); |
| 113 | + h attn_p = attn_params(arm, dict_get(model, "attn")); |
| 114 | + h other = prom_collect_params_v2([ |
| 115 | + dict_get(model, "emb"), |
| 116 | + dict_get(model, "ln1"), |
| 117 | + dict_get(model, "ff_up"), |
| 118 | + dict_get(model, "ff_down"), |
| 119 | + dict_get(model, "ln2"), |
| 120 | + dict_get(model, "head"), |
| 121 | + ]); |
| 122 | + h out = []; |
| 123 | + h i = 0; |
| 124 | + while i < arr_len(attn_p) { arr_push(out, arr_get(attn_p, i)); i = i + 1; } |
| 125 | + i = 0; |
| 126 | + while i < arr_len(other) { arr_push(out, arr_get(other, i)); i = i + 1; } |
| 127 | + return out; |
| 128 | +} |
| 129 | + |
| 130 | +fn train(arm, vocab_size, ids, seq_len, d_model, ff_dim, n_heads, lr, steps, seed) { |
| 131 | + tape_reset(); |
| 132 | + h model = build_model(arm, vocab_size, d_model, ff_dim, seq_len, n_heads, seed); |
| 133 | + h params = collect_all(model); |
| 134 | + h opt = prom_adamw_new(params, lr, 0.9, 0.999, 1e-8, 0.0); |
| 135 | + h pe_table = prom_crt_pe_matrix(seq_len, d_model); |
| 136 | + h n_windows = arr_len(ids) - seq_len - 1; |
| 137 | + h tail = []; |
| 138 | + h step = 0; |
| 139 | + while step < steps { |
| 140 | + h start = step - (step / n_windows) * n_windows; |
| 141 | + h window = []; |
| 142 | + h targets = []; |
| 143 | + h k = 0; |
| 144 | + while k < seq_len { |
| 145 | + arr_push(window, arr_get(ids, start + k)); |
| 146 | + arr_push(targets, arr_get(ids, start + k + 1)); |
| 147 | + k = k + 1; |
| 148 | + } |
| 149 | + h logits = forward_window(model, window, pe_table); |
| 150 | + h loss = prom_cross_entropy_batch(logits, targets, vocab_size); |
| 151 | + tape_backward(loss); |
| 152 | + prom_adamw_step(opt); |
| 153 | + if step >= steps - 30 { arr_push(tail, tape_value(loss)); } |
| 154 | + step = step + 1; |
| 155 | + } |
| 156 | + h s = 0.0; h i = 0; |
| 157 | + while i < arr_len(tail) { s = s + arr_get(tail, i); i = i + 1; } |
| 158 | + return s / arr_len(tail); |
| 159 | +} |
| 160 | + |
| 161 | +fn mean_arr(xs) { |
| 162 | + h s = 0.0; h i = 0; |
| 163 | + while i < arr_len(xs) { s = s + arr_get(xs, i); i = i + 1; } |
| 164 | + return s / arr_len(xs); |
| 165 | +} |
| 166 | + |
| 167 | +fn main() { |
| 168 | + print("=== MH+Q6 compound test (#3) ==="); |
| 169 | + h text = "the rain in spain falls mainly on the plain and the sun rises in the east while the moon hides behind the mountain peaks of distant lands where ancient creatures sleep in caves of silver beneath the stars"; |
| 170 | + h vocab = build_vocab(text); |
| 171 | + h vocab_size = arr_len(dict_get(vocab, "chars")); |
| 172 | + h ids = encode(text, vocab); |
| 173 | + h seq_len = 16; |
| 174 | + h d_model = 32; |
| 175 | + h ff_dim = 64; |
| 176 | + h n_heads = 4; |
| 177 | + h lr = 0.005; |
| 178 | + h steps = 250; |
| 179 | + h seeds = [42, 7, 123]; |
| 180 | + |
| 181 | + print(concat_many("d_model=", to_string(d_model), |
| 182 | + " n_heads=", to_string(n_heads), |
| 183 | + " steps=", to_string(steps), |
| 184 | + " seeds=", to_string(arr_len(seeds)))); |
| 185 | + print(""); |
| 186 | + |
| 187 | + h arms = ["SH", "SHQ6", "MH", "MHQ6"]; |
| 188 | + h labels = dict_new(); |
| 189 | + dict_set(labels, "SH", "SH "); |
| 190 | + dict_set(labels, "SHQ6", "SH + Q6 "); |
| 191 | + dict_set(labels, "MH", "MH (4h) "); |
| 192 | + dict_set(labels, "MHQ6", "MH (4h) + Q6"); |
| 193 | + |
| 194 | + h results = dict_new(); |
| 195 | + h ai = 0; |
| 196 | + while ai < arr_len(arms) { |
| 197 | + h arm = arr_get(arms, ai); |
| 198 | + h losses = []; |
| 199 | + h si = 0; |
| 200 | + while si < arr_len(seeds) { |
| 201 | + h seed = arr_get(seeds, si); |
| 202 | + h L = train(arm, vocab_size, ids, seq_len, d_model, ff_dim, n_heads, lr, steps, seed); |
| 203 | + arr_push(losses, L); |
| 204 | + si = si + 1; |
| 205 | + } |
| 206 | + dict_set(results, arm, losses); |
| 207 | + h mu = mean_arr(losses); |
| 208 | + print(concat_many(dict_get(labels, arm), " mean=", to_string(mu))); |
| 209 | + ai = ai + 1; |
| 210 | + } |
| 211 | + |
| 212 | + print(""); |
| 213 | + print("=== compound analysis ==="); |
| 214 | + h sh_mu = mean_arr(dict_get(results, "SH")); |
| 215 | + h shq6_mu = mean_arr(dict_get(results, "SHQ6")); |
| 216 | + h mh_mu = mean_arr(dict_get(results, "MH")); |
| 217 | + h mhq6_mu = mean_arr(dict_get(results, "MHQ6")); |
| 218 | + print(concat_many("SH→SHQ6 Δ=", to_string(shq6_mu - sh_mu), |
| 219 | + " (", to_string((shq6_mu - sh_mu) / sh_mu * 100.0), "%)")); |
| 220 | + print(concat_many("MH→MHQ6 Δ=", to_string(mhq6_mu - mh_mu), |
| 221 | + " (", to_string((mhq6_mu - mh_mu) / mh_mu * 100.0), "%)")); |
| 222 | + print(concat_many("SH→MH Δ=", to_string(mh_mu - sh_mu), |
| 223 | + " (", to_string((mh_mu - sh_mu) / sh_mu * 100.0), "%)")); |
| 224 | + print(concat_many("SH→MHQ6 Δ=", to_string(mhq6_mu - sh_mu), |
| 225 | + " (", to_string((mhq6_mu - sh_mu) / sh_mu * 100.0), "%) ← compound")); |
| 226 | +} |
| 227 | + |
| 228 | +main(); |
0 commit comments