@@ -727,11 +727,11 @@ fn _prom_substrate_resample_matrix(v_val, scale) {
727727# flows through v unchanged (modulation rides as a const). scale=0.0
728728# disables (returns v unchanged).
fn prom_substrate_resample(v_id, scale) {
    # Applies substrate resampling to a tape value.
    #   v_id  — tape node id of the value to modulate.
    #   scale — modulation strength; 0.0 disables the op.
    # Returns a tape node id (v_id itself when scale == 0.0).
    #
    # The scale == 0.0 short-circuit is kept on the OMC side so the
    # "returns v unchanged" contract holds no matter how the builtin
    # treats a zero scale, and so a disabled resample adds no tape node.
    if scale == 0.0 { return v_id; }
    # v0.8.5 — defers to the fused tape_substrate_resample Rust builtin.
    # Skips the tape_value → modulator_matrix → tape_const round-trip
    # the old composition did (16k f64 cells at d_model=256 seq_len=64
    # being extracted and re-lifted per call).
    return tape_substrate_resample(v_id, scale);
}
736736
737737# Substrate-modulated softmax. alpha=0.0 returns standard softmax.
@@ -936,6 +936,121 @@ fn prom_attention_substrate_full_forward(layer, x_id) {
936936 return tape_matmul(attn, x_id);
937937}
938938
939+ # ---------------------------------------------------------------------------
940+ # Multi-head substrate-K attention.
941+ #
942+ # PyTorch's L1-MH (-8.94% val) and Q6-MH (-12.15% val) findings need
943+ # multi-head capacity to fire — single-head OMC saw roughly a third of
944+ # those wins (SUBSTRATE_STACK_OMC_XVAL.md). This wraps n_heads independent
945+ # substrate-K heads with per-head W_O projections. Their outputs SUM to
946+ # the final attention output, which is mathematically identical to the
947+ # standard "concat then project" formulation:
948+ #
949+ # standard: out = concat(out_1, ..., out_h) @ W_O
950+ # W_O is [n_h·d_head, d_model] = block-stack of [W_O_1; ...; W_O_h]
951+ # so: out = sum_h(out_h @ W_O_h) ← what we use
952+ #
953+ # Equivalent math, no tape_concat op needed. Each head gets its own
954+ # Q_h, V_h, W_O_h, and per-head substrate K (CRT-PE of width d_head).
955+ # All the single-head toggles (smod_alpha, v_resample_scale, q6_mode)
956+ # are honored per-head; defaults match the single-head layer.
957+ # ---------------------------------------------------------------------------
958+
fn prom_attention_substrate_k_mh_new(d_model, seq_len, n_heads, rng_state) {
    # Builds a multi-head substrate-K attention layer.
    #   d_model   — model width; rows of each Q/V, columns of each W_O.
    #   seq_len   — first dimension passed to prom_crt_pe_matrix per head.
    #   n_heads   — number of heads to allocate.
    #   rng_state — RNG state threaded through every random-matrix draw.
    # Returns a layer dict holding "heads" (one dict per head with the
    # Q, V, W_O nodes and K_const), the toggle defaults, and the RNG
    # state left over after the last draw under "rng_state".
    #
    # NOTE: d_model is expected to be divisible by n_heads, but nothing
    # here enforces it (OMC has no error-throw primitive here). Every
    # per-head shape below is derived from the same d_head, so a
    # non-divisible d_model does NOT surface as a shape mismatch in this
    # layer — the heads are simply built with d_head = d_model / n_heads
    # as OMC evaluates it (presumably floored for integers — TODO confirm).
    # Callers must validate divisibility themselves.
    h d_head = d_model / n_heads;
    h heads = [];
    h state = rng_state;
    h hi = 0;
    while hi < n_heads {
        # Draw order (Q, then V, then W_O) fixes how the RNG state
        # advances; keep it stable so seeds stay reproducible.
        h Q = _prom_random_matrix(d_model, d_head, 0.3, state);
        state = dict_get(Q, "state");
        h V = _prom_random_matrix(d_model, d_head, 0.3, state);
        state = dict_get(V, "state");
        h W_O = _prom_random_matrix(d_head, d_model, 0.3, state);
        state = dict_get(W_O, "state");
        h head = dict_new();
        dict_set(head, "Q", dict_get(Q, "node"));
        dict_set(head, "V", dict_get(V, "node"));
        dict_set(head, "W_O", dict_get(W_O, "node"));
        # K is a fixed CRT-PE matrix, stored raw and lifted to a tape
        # const at forward time.
        dict_set(head, "K_const", prom_crt_pe_matrix(seq_len, d_head));
        arr_push(heads, head);
        hi = hi + 1;
    }
    h layer = dict_new();
    dict_set(layer, "kind", "attention");
    dict_set(layer, "variant", "substrate_k_mh");
    dict_set(layer, "d_model", d_model);
    dict_set(layer, "d_head", d_head);
    dict_set(layer, "n_heads", n_heads);
    dict_set(layer, "seq_len", seq_len);
    dict_set(layer, "heads", heads);
    # Per-head toggles match the single-head substrate_k_new defaults.
    dict_set(layer, "smod_alpha", 1.0);
    dict_set(layer, "v_resample_scale", 10.0);
    dict_set(layer, "q6_mode", "off");
    dict_set(layer, "q6_scale", 10.0);
    dict_set(layer, "q6_gamma", 0.5);
    dict_set(layer, "rng_state", state);
    return layer;
}
1001+
fn prom_attention_substrate_k_mh_forward(layer, x_id) {
    # Forward pass of the multi-head substrate-K attention layer.
    #   layer — dict produced by prom_attention_substrate_k_mh_new.
    #   x_id  — tape node id of the input.
    # Runs every head on x_id and adds the per-head W_O projections
    # together. Returns the tape node id of that sum, or null when the
    # layer has zero heads.
    h head_count = dict_get(layer, "n_heads");
    h head_list = dict_get(layer, "heads");
    # Layer-wide toggles, applied identically to every head.
    h alpha = dict_get(layer, "smod_alpha");
    h resample_scale = dict_get(layer, "v_resample_scale");
    h mode_q6 = dict_get(layer, "q6_mode");
    h scale_q6 = dict_get(layer, "q6_scale");
    h gamma_q6 = dict_get(layer, "q6_gamma");

    h total = null;
    h idx = 0;
    while idx < head_count {
        h cur = arr_get(head_list, idx);
        h w_q = dict_get(cur, "Q");
        h w_v = dict_get(cur, "V");
        h w_out = dict_get(cur, "W_O");
        h k_raw = dict_get(cur, "K_const");

        # Query path: project, then Q6-modulate.
        h q_proj = tape_matmul(x_id, w_q);                       # [N, d_head]
        h q_final = prom_q6_modulate(q_proj, scale_q6, gamma_q6, mode_q6);
        # Value path: project, then substrate-resample.
        h v_proj = tape_matmul(x_id, w_v);                       # [N, d_head]
        h v_final = prom_substrate_resample(v_proj, resample_scale);
        # Key path: the stored matrix enters the tape as a constant.
        h k_node = tape_const(k_raw);
        h k_t = tape_transpose(k_node);
        h logits = tape_matmul(q_final, k_t);                    # [N, seq_len]
        h weights = prom_substrate_softmax(logits, alpha);
        h head_out = tape_matmul(weights, v_final);              # [N, d_head]
        h head_proj = tape_matmul(head_out, w_out);              # [N, d_model]

        # The first head seeds the accumulator; later heads are added.
        if total != null {
            total = tape_add(total, head_proj);
        } else {
            total = head_proj;
        }
        idx = idx + 1;
    }
    return total;
}
1039+
fn prom_attention_substrate_k_mh_params(layer) {
    # Collects the trainable nodes of a multi-head substrate-K layer.
    # Returns a flat array ordered head by head as Q, V, W_O.
    # K_const is not included.
    h head_list = dict_get(layer, "heads");
    h params = [];
    h idx = 0;
    while idx < arr_len(head_list) {
        h cur = arr_get(head_list, idx);
        h w_q = dict_get(cur, "Q");
        h w_v = dict_get(cur, "V");
        h w_out = dict_get(cur, "W_O");
        arr_push(params, w_q);
        arr_push(params, w_v);
        arr_push(params, w_out);
        idx = idx + 1;
    }
    return params;
}
1053+
9391054# Param collectors per variant.
9401055fn prom_attention_substrate_k_params(layer) {
9411056 return [dict_get(layer, "Q"), dict_get(layer, "V")];
@@ -1141,55 +1256,23 @@ fn prom_embedding_params(layer) {
11411256# Implemented via an [N, vocab] one-hot batch then matmul with the
11421257# embedding table. Differentiable end-to-end.
fn prom_embedding_batch(layer, token_ids) {
    # Looks up the embedding rows for token_ids from the layer's "table".
    #
    # v0.8.5 — handled by the tape_embedding_lookup Rust builtin: a direct
    # row gather instead of building an [N, vocab] one-hot in OMC and
    # matmulling. Backward scatters dL/dout rows back into the table grad,
    # which is the same gradient the one-hot @ table chain produced.
    return tape_embedding_lookup(dict_get(layer, "table"), token_ids);
}
11641266
11651267# Batched cross-entropy: logits is [N, vocab], targets is array of N
11661268# integer indices. Returns scalar mean loss (averaged over positions).
1269+ #
1270+ # v0.8.5 — defers to the fused tape_cross_entropy_batch Rust builtin
1271+ # (closed-form (p - one_hot) / N backward, no intermediate tape nodes).
1272+ # `vocab` is accepted but unused (the builtin reads it from logits.cols);
1273+ # kept in the signature for callers that pass it.
fn prom_cross_entropy_batch(logits_id, targets, vocab) {
    # Batched cross-entropy over logits_id against the integer targets;
    # yields the tape node id of the loss.
    #
    # v0.8.5 — the whole computation lives in the fused
    # tape_cross_entropy_batch Rust builtin. `vocab` is deliberately not
    # forwarded (the builtin reads it from logits.cols); the parameter
    # stays only so existing callers that pass it keep working.
    h loss_id = tape_cross_entropy_batch(logits_id, targets);
    return loss_id;
}
11941277
11951278# ---------------------------------------------------------------------------
0 commit comments