|
| 1 | +# v0.8.11 reformulations of the v0.8.10 substrate-aware backward |
| 2 | +# falsification. Same hypothesis (substrate as gradient-flow |
| 3 | +# regularizer), three different applications: |
| 4 | +# |
| 5 | +# R1 decay-alpha: alpha=0.5 → 0.0 linearly over training (warm start) |
| 6 | +# R2 FF-only: apply substrate gm to FF up/down weights, not Q/V |
| 7 | +# R4 scale=1024: gentler bias (finer attractor grid → less coarse pull) |
| 8 | +# |
| 9 | +# Plus baseline + v0.8.10 reference. 3 seeds, 250 steps, d_model=32. |
| 10 | + |
| 11 | +import "examples/lib/prometheus.omc"; |
| 12 | + |
# Build a character-level vocabulary from `text`.
#
# Scans `text` in one-unit slices (str_slice(text, i, i + 1)) and gives every
# slice not seen before the next free id, in order of first appearance.
#
# Returns a dict with two entries:
#   "chars"  -> array of the distinct slices; array position is the id
#   "lookup" -> dict mapping each slice to its id
fn build_vocab(text) {
    h alphabet = [];
    h index_of = dict_new();
    h total = str_len(text);
    h pos = 0;
    while pos < total {
        h c = str_slice(text, pos, pos + 1);
        if !dict_has(index_of, c) {
            # The id is the array length *before* the push, so ids stay dense.
            dict_set(index_of, c, arr_len(alphabet));
            arr_push(alphabet, c);
        }
        pos = pos + 1;
    }
    h vocab = dict_new();
    dict_set(vocab, "chars", alphabet);
    dict_set(vocab, "lookup", index_of);
    return vocab;
}
| 27 | + |
# Convert `text` into an array of token ids using vocab["lookup"].
#
# One id is appended per one-unit slice of `text`, in order. Whatever
# dict_get yields for a slice is appended as-is, so a slice missing from the
# lookup contributes dict_get's missing-key result (not checked here).
fn encode(text, vocab) {
    h table = dict_get(vocab, "lookup");
    h total = str_len(text);
    h encoded = [];
    h pos = 0;
    while pos < total {
        h token = dict_get(table, str_slice(text, pos, pos + 1));
        arr_push(encoded, token);
        pos = pos + 1;
    }
    return encoded;
}
| 39 | + |
# Attention forward pass with the substrate gradient modifier applied to the
# Q and V weights before they are used.
#
# Params:
#   layer    - attention layer dict; reads "Q", "V", "K_const", "smod_alpha"
#              and the optional "v_resample_scale"
#   x_id     - input passed to tape_matmul against the modified Q and V
#   gm_scale - scale argument forwarded to tape_substrate_grad_mod
#   gm_alpha - alpha argument forwarded to tape_substrate_grad_mod
#
# Returns tape_matmul(attn, v), where attn = substrate softmax of q * K^T.
fn attn_forward_gm_qv(layer, x_id, gm_scale, gm_alpha) {
    h resample_scale = dict_get(layer, "v_resample_scale");
    # A layer without "v_resample_scale" resamples with scale 0.0.
    if resample_scale == null { resample_scale = 0.0; }

    # Wrap both projection weights; Q first, then V.
    h q_weights = tape_substrate_grad_mod(dict_get(layer, "Q"), gm_scale, gm_alpha);
    h v_weights = tape_substrate_grad_mod(dict_get(layer, "V"), gm_scale, gm_alpha);

    h queries = tape_matmul(x_id, q_weights);
    h values = prom_substrate_resample(tape_matmul(x_id, v_weights), resample_scale);

    # K enters the tape through tape_const (the layer's "K_const" entry).
    h keys_t = tape_transpose(tape_const(dict_get(layer, "K_const")));
    h weights = prom_substrate_softmax(tape_matmul(queries, keys_t), dict_get(layer, "smod_alpha"));
    return tape_matmul(weights, values);
}
| 59 | + |
# Linear forward pass (x * W + b) with the substrate gradient modifier applied
# to W; the bias "b" is added unmodified.
#
# Params:
#   layer    - linear layer dict; reads "W" and "b"
#   x_id     - input passed to tape_matmul
#   gm_scale - scale argument forwarded to tape_substrate_grad_mod
#   gm_alpha - alpha argument forwarded to tape_substrate_grad_mod
fn linear_forward_gm(layer, x_id, gm_scale, gm_alpha) {
    h weight_mod = tape_substrate_grad_mod(dict_get(layer, "W"), gm_scale, gm_alpha);
    h bias = dict_get(layer, "b");
    return tape_add(tape_matmul(x_id, weight_mod), bias);
}
| 68 | + |
# Construct the one-block model used by every experimental arm.
#
# Layers are created in a fixed order (emb, attn, ln1, ff_up, ff_down, ln2,
# head). Each layer's seed is derived from the "rng_state" left behind by an
# earlier layer, with fixed offsets (+11, +13, +17), so one `seed` determines
# the whole model.
#
# Returns a dict holding "arm" plus one entry per layer.
fn build_model(arm, vocab_size, d_model, ff_dim, seq_len, seed) {
    h model = dict_new();
    dict_set(model, "arm", arm);

    h embedding = prom_embedding_new(vocab_size, d_model, seed);
    dict_set(model, "emb", embedding);
    h after_emb = dict_get(embedding, "rng_state");

    h attention = prom_attention_substrate_k_new(d_model, seq_len, after_emb + 11);
    dict_set(model, "attn", attention);
    h after_attn = dict_get(attention, "rng_state");

    # ln1 and ff_up both derive their seed from the attention layer's state.
    h norm1 = prom_layernorm_new(d_model, after_attn);
    dict_set(model, "ln1", norm1);

    h up_proj = prom_linear_new(d_model, ff_dim, after_attn + 13);
    dict_set(model, "ff_up", up_proj);
    h after_up = dict_get(up_proj, "rng_state");

    h down_proj = prom_linear_new(ff_dim, d_model, after_up);
    dict_set(model, "ff_down", down_proj);
    h after_down = dict_get(down_proj, "rng_state");

    # ln2 and head both derive their seed from ff_down's state.
    h norm2 = prom_layernorm_new(d_model, after_down);
    dict_set(model, "ln2", norm2);

    h out_head = prom_linear_new(d_model, vocab_size, after_down + 17);
    dict_set(model, "head", out_head);

    return model;
}
| 92 | + |
# Forward pass for one window of tokens. Where the substrate gradient
# modifier (gm) is applied depends on the model's "arm":
#
#   "baseline"     - plain attention, plain feed-forward
#   "v0810ref"     - gm on attention Q/V, scale 64.0, alpha 0.5
#   "R1_decay"     - gm on attention Q/V, scale 64.0, alpha = alpha_now
#   "R2_ff_only"   - plain attention; gm on ff_up/ff_down, scale 64.0, alpha 0.5
#   "R4_scale1024" - gm on attention Q/V, scale 1024.0, alpha 0.5
#
# Params:
#   model     - dict from build_model
#   token_ids - token ids for the window
#   pe_table  - positional rows; row i is added at window position i
#   alpha_now - gm alpha for this step; read only by the "R1_decay" arm
#   scale_now - kept for caller compatibility; no arm reads it (every arm
#               uses a fixed gm scale)
#
# Returns the head layer's output for the window. For an unrecognised arm it
# prints a diagnostic and returns null instead of feeding a null attention
# output into tape_add.
fn forward_window(model, token_ids, pe_table, alpha_now, scale_now) {
    h arm = dict_get(model, "arm");
    h attn_layer = dict_get(model, "attn");

    # Token embeddings plus the first arr_len(token_ids) positional rows.
    h x = prom_embedding_batch(dict_get(model, "emb"), token_ids);
    h pe_rows = [];
    h i = 0;
    while i < arr_len(token_ids) {
        arr_push(pe_rows, arr_get(pe_table, i));
        i = i + 1;
    }
    x = tape_add(x, tape_const(pe_rows));

    h attn_out = null;
    if arm == "baseline" {
        attn_out = prom_attention_substrate_k_forward(attn_layer, x);
    } elif arm == "v0810ref" {
        attn_out = attn_forward_gm_qv(attn_layer, x, 64.0, 0.5);
    } elif arm == "R1_decay" {
        attn_out = attn_forward_gm_qv(attn_layer, x, 64.0, alpha_now);
    } elif arm == "R2_ff_only" {
        # R2 leaves attention untouched; its gm is applied in the FF block below.
        attn_out = prom_attention_substrate_k_forward(attn_layer, x);
    } elif arm == "R4_scale1024" {
        attn_out = attn_forward_gm_qv(attn_layer, x, 1024.0, 0.5);
    } else {
        # Previously an unknown arm fell through with attn_out still null and
        # failed later inside tape_add with no hint about the cause.
        print(concat_many("forward_window: unknown arm '", to_string(arm), "'"));
        return null;
    }

    # Residual connection around attention, then the first layer norm.
    h x_post = tape_add(x, attn_out);
    h n1 = prom_layernorm_forward(dict_get(model, "ln1"), x_post);

    h up = null;
    h down = null;
    if arm == "R2_ff_only" {
        # Apply substrate gm to FF up/down weights only.
        up = linear_forward_gm(dict_get(model, "ff_up"), n1, 64.0, 0.5);
        down = linear_forward_gm(dict_get(model, "ff_down"), prom_relu(up), 64.0, 0.5);
    } else {
        up = prom_linear_forward(dict_get(model, "ff_up"), n1);
        down = prom_linear_forward(dict_get(model, "ff_down"), prom_relu(up));
    }

    # Residual connection around the FF block, second layer norm, output head.
    h x_ff = tape_add(x_post, down);
    h n2 = prom_layernorm_forward(dict_get(model, "ln2"), x_ff);
    return prom_linear_forward(dict_get(model, "head"), n2);
}
| 131 | + |
# Gather every trainable parameter of `model` into one flat array:
# the attention layer's parameters first, followed by whatever
# prom_collect_params_v2 returns for emb, ln1, ff_up, ff_down, ln2 and head.
fn collect_all(model) {
    h from_attn = prom_attention_substrate_k_params(dict_get(model, "attn"));
    h remaining_layers = [
        dict_get(model, "emb"),
        dict_get(model, "ln1"),
        dict_get(model, "ff_up"),
        dict_get(model, "ff_down"),
        dict_get(model, "ln2"),
        dict_get(model, "head"),
    ];
    h from_rest = prom_collect_params_v2(remaining_layers);

    h merged = [];
    h a = 0;
    while a < arr_len(from_attn) {
        arr_push(merged, arr_get(from_attn, a));
        a = a + 1;
    }
    h r = 0;
    while r < arr_len(from_rest) {
        arr_push(merged, arr_get(from_rest, r));
        r = r + 1;
    }
    return merged;
}
| 149 | + |
# Train one model for `steps` AdamW updates and return the mean loss over the
# final 30 steps (over all steps when steps < 30).
#
# The tape is reset once, before the model is built. Step s trains on the
# window starting at s modulo n_windows; targets are the same window shifted
# one token to the right.
#
# alpha_now follows 0.5 * (1 - step / steps). It starts at 0.5 and ends at
# 0.5 / steps on the last step, so it approaches but never exactly reaches
# 0.0. It is passed to every arm; forward_window decides whether to use it.
fn train(arm, vocab_size, ids, seq_len, d_model, ff_dim, lr, steps, seed) {
    tape_reset();
    h model = build_model(arm, vocab_size, d_model, ff_dim, seq_len, seed);
    h params = collect_all(model);
    h optimizer = prom_adamw_new(params, lr, 0.9, 0.999, 1e-8, 0.0);
    h pe_table = prom_crt_pe_matrix(seq_len, d_model);
    h n_windows = arr_len(ids) - seq_len - 1;

    # Running sum/count of the losses recorded during the tail of training.
    h tail_from = steps - 30;
    h tail_sum = 0.0;
    h tail_count = 0;

    h step = 0;
    while step < steps {
        # alpha schedule for R1: linear decay from 0.5 towards 0.0
        h alpha_now = 0.5 * (1.0 - step * 1.0 / steps);

        # step modulo n_windows, spelled with integer division.
        h offset = step - (step / n_windows) * n_windows;

        h inputs = [];
        h expected = [];
        h k = 0;
        while k < seq_len {
            h at = offset + k;
            arr_push(inputs, arr_get(ids, at));
            arr_push(expected, arr_get(ids, at + 1));
            k = k + 1;
        }

        h logits = forward_window(model, inputs, pe_table, alpha_now, 64.0);
        h loss = prom_cross_entropy_batch(logits, expected, vocab_size);
        tape_backward(loss);
        prom_adamw_step(optimizer);

        if step >= tail_from {
            tail_sum = tail_sum + tape_value(loss);
            tail_count = tail_count + 1;
        }
        step = step + 1;
    }
    return tail_sum / tail_count;
}
| 183 | + |
# Arithmetic mean of the numbers in `xs`: the left-to-right sum (starting
# from 0.0) divided by arr_len(xs). An empty array is not special-cased.
fn mean_arr(xs) {
    h count = arr_len(xs);
    h total = 0.0;
    h idx = 0;
    while idx < count {
        total = total + arr_get(xs, idx);
        idx = idx + 1;
    }
    return total / count;
}
| 189 | + |
# Experiment driver.
#
# Trains every arm once per seed on a fixed sentence (character-level),
# prints each arm's mean tail loss as soon as the arm finishes, then prints a
# headline table comparing every arm with the baseline: mean, absolute and
# percentage difference, and the number of seeds on which the arm's loss is
# strictly below the baseline's loss for the same seed.
fn main() {
    print("=== v0.8.11 substrate-grad-mod reformulations ===");

    # Corpus and character vocabulary.
    h text = "the rain in spain falls mainly on the plain and the sun rises in the east while the moon hides behind the mountain peaks of distant lands";
    h vocab = build_vocab(text);
    h vocab_size = arr_len(dict_get(vocab, "chars"));
    h ids = encode(text, vocab);

    # Hyper-parameters shared by every arm.
    h seq_len = 16;
    h d_model = 32;
    h ff_dim = 64;
    h lr = 0.005;
    h steps = 250;
    h seeds = [42, 7, 123];
    h n_seeds = arr_len(seeds);

    print(concat_many("d_model=", to_string(d_model),
                      " steps=", to_string(steps),
                      " seeds=", to_string(arr_len(seeds))));
    print("");

    h arms = ["baseline", "v0810ref", "R1_decay", "R2_ff_only", "R4_scale1024"];
    h n_arms = arr_len(arms);

    # Display label for each arm.
    h labels = dict_new();
    dict_set(labels, "baseline", "baseline (no gm) ");
    dict_set(labels, "v0810ref", "v0810 ref (gm Q/V α=0.5) ");
    dict_set(labels, "R1_decay", "R1 decay α 0.5→0 ");
    dict_set(labels, "R2_ff_only", "R2 FF only ");
    dict_set(labels, "R4_scale1024", "R4 scale=1024 (finer) ");

    # Phase 1: run every (arm, seed) pair; results[arm] holds one loss per seed.
    h results = dict_new();
    h a = 0;
    while a < n_arms {
        h arm_name = arr_get(arms, a);
        h per_seed = [];
        h s = 0;
        while s < n_seeds {
            h final_loss = train(arm_name, vocab_size, ids, seq_len, d_model, ff_dim, lr, steps, arr_get(seeds, s));
            arr_push(per_seed, final_loss);
            s = s + 1;
        }
        dict_set(results, arm_name, per_seed);
        print(concat_many(dict_get(labels, arm_name), " mean=", to_string(mean_arr(per_seed))));
        a = a + 1;
    }

    # Phase 2: compare each arm (the baseline included) against the baseline.
    print("");
    print("=== headline ===");
    h base_losses = dict_get(results, "baseline");
    h base_mu = mean_arr(base_losses);
    a = 0;
    while a < n_arms {
        h arm_name = arr_get(arms, a);
        h arm_losses = dict_get(results, arm_name);
        h arm_mu = mean_arr(arm_losses);
        h delta = arm_mu - base_mu;
        h pct = (delta / base_mu) * 100.0;

        # A "win" is a strictly lower loss than the baseline on the same seed.
        h wins = 0;
        h s = 0;
        while s < n_seeds {
            if arr_get(arm_losses, s) < arr_get(base_losses, s) {
                wins = wins + 1;
            }
            s = s + 1;
        }

        print(concat_many(dict_get(labels, arm_name),
                          " mean=", to_string(arm_mu),
                          " Δ=", to_string(delta),
                          " (", to_string(pct), "%)",
                          " wins ", to_string(wins), "/", to_string(arr_len(seeds))));
        a = a + 1;
    }
}

main();
0 commit comments