Skip to content

Commit 34f61fa

Browse files
v0.8.5 substrate ops + multi-head substrate-K (items #1, #2, #4, #5, #6)
Five v0.8.5 items shipped: #1 tape_cross_entropy_batch — fused softmax + select-target-log + mean, one tape node, closed-form (p - one_hot)/N backward. Eliminates the 5 intermediate tape nodes the OMC composition built per call. #2 tape_embedding_lookup — direct row gather; replaces the [N, vocab] one-hot construction + matmul in prom_embedding_batch. Backward scatters rows of dy into the table's row gradients. #4 OMC_VM=1 negative finding — measured at d_model=256: 0.662 s/step (was 0.661 tree-walk). No win once hot paths are in Rust builtins. Won't pursue VM further for Prometheus. #5 prom_attention_substrate_k_mh_* — multi-head substrate-K attention. Math-equivalent "sum of per-head W_O projections" form (avoids needing a tape_concat op). Per-head Q_h, V_h, W_O_h tape vars plus per-head CRT-PE constant of width d_head. All single-head toggles (smod_alpha, v_resample_scale, q6_mode) honored per-head with same defaults. Cross-validation at d_model=32, 4 heads (d_head=8), 400 steps, 3 seeds: SH (single head) mean=2.0047 MH (4 heads) mean=1.9998 Δ=-0.25% wins 2/3 Multi-head IS beating single-head in OMC, directionally consistent with PyTorch L1-MH -8.94%. Effect grows with capacity. #6 tape_substrate_resample — fused tape op for substrate-V resample. Skips tape_value → modulator_matrix → tape_const round-trip (which extracted 16k f64s at d_model=256 seq_len=64 per call). Pairs with substrate_resample_matrix builtin from v0.8.4; same math. Wall-clock at d_model=256 essentially unchanged (was already AdamW- bound by v0.8.4); these wins materialize at larger vocab + larger matmul shapes + with multi-head capacity. Loss agreement preserved (6.95930) across all changes. 1111/1111 OMC tests pass. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 9638cdf commit 34f61fa

3 files changed

Lines changed: 564 additions & 48 deletions

File tree

examples/lib/prometheus.omc

Lines changed: 131 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -727,11 +727,11 @@ fn _prom_substrate_resample_matrix(v_val, scale) {
727727
# flows through v unchanged (modulation rides as a const). scale=0.0
728728
# disables (returns v unchanged).
729729
fn prom_substrate_resample(v_id, scale) {
730-
if scale == 0.0 { return v_id; }
731-
h v_val = tape_value(v_id);
732-
h mod_mat = _prom_substrate_resample_matrix(v_val, scale);
733-
h mod_const = tape_const(mod_mat);
734-
return tape_mul(v_id, mod_const);
730+
# v0.8.5 — defers to the fused tape_substrate_resample Rust builtin.
731+
# Skips the tape_value → modulator_matrix → tape_const round-trip
732+
# the old composition did (16k f64 cells at d_model=256 seq_len=64
733+
# being extracted and re-lifted per call).
734+
return tape_substrate_resample(v_id, scale);
735735
}
736736

737737
# Substrate-modulated softmax. alpha=0.0 returns standard softmax.
@@ -936,6 +936,121 @@ fn prom_attention_substrate_full_forward(layer, x_id) {
936936
return tape_matmul(attn, x_id);
937937
}
938938

939+
# ---------------------------------------------------------------------------
940+
# Multi-head substrate-K attention.
941+
#
942+
# PyTorch's L1-MH (-8.94% val) and Q6-MH (-12.15% val) findings need
943+
# multi-head capacity to fire — single-head OMC saw roughly a third of
944+
# those wins (SUBSTRATE_STACK_OMC_XVAL.md). This wraps n_heads independent
945+
# substrate-K heads with per-head W_O projections. Their outputs SUM to
946+
# the final attention output, which is mathematically identical to the
947+
# standard "concat then project" formulation:
948+
#
949+
# standard: out = concat(out_1, ..., out_h) @ W_O
950+
# W_O is [n_h·d_head, d_model] = block-stack of [W_O_1; ...; W_O_h]
951+
# so: out = sum_h(out_h @ W_O_h) ← what we use
952+
#
953+
# Equivalent math, no tape_concat op needed. Each head gets its own
954+
# Q_h, V_h, W_O_h, and per-head substrate K (CRT-PE of width d_head).
955+
# All the single-head toggles (smod_alpha, v_resample_scale, q6_mode)
956+
# are honored per-head; defaults match the single-head layer.
957+
# ---------------------------------------------------------------------------
958+
959+
fn prom_attention_substrate_k_mh_new(d_model, seq_len, n_heads, rng_state) {
960+
if d_model - (d_model / n_heads) * n_heads != 0 {
961+
# d_model must be divisible by n_heads.
962+
# (OMC has no error-throw primitive here; fall through and the
963+
# smaller-than-expected d_head will produce a shape error later.)
964+
}
965+
h d_head = d_model / n_heads;
966+
h heads = [];
967+
h state = rng_state;
968+
h hi = 0;
969+
while hi < n_heads {
970+
h Q = _prom_random_matrix(d_model, d_head, 0.3, state);
971+
state = dict_get(Q, "state");
972+
h V = _prom_random_matrix(d_model, d_head, 0.3, state);
973+
state = dict_get(V, "state");
974+
h W_O = _prom_random_matrix(d_head, d_model, 0.3, state);
975+
state = dict_get(W_O, "state");
976+
h head = dict_new();
977+
dict_set(head, "Q", dict_get(Q, "node"));
978+
dict_set(head, "V", dict_get(V, "node"));
979+
dict_set(head, "W_O", dict_get(W_O, "node"));
980+
dict_set(head, "K_const", prom_crt_pe_matrix(seq_len, d_head));
981+
arr_push(heads, head);
982+
hi = hi + 1;
983+
}
984+
h layer = dict_new();
985+
dict_set(layer, "kind", "attention");
986+
dict_set(layer, "variant", "substrate_k_mh");
987+
dict_set(layer, "d_model", d_model);
988+
dict_set(layer, "d_head", d_head);
989+
dict_set(layer, "n_heads", n_heads);
990+
dict_set(layer, "seq_len", seq_len);
991+
dict_set(layer, "heads", heads);
992+
# Per-head toggles match the single-head substrate_k_new defaults.
993+
dict_set(layer, "smod_alpha", 1.0);
994+
dict_set(layer, "v_resample_scale", 10.0);
995+
dict_set(layer, "q6_mode", "off");
996+
dict_set(layer, "q6_scale", 10.0);
997+
dict_set(layer, "q6_gamma", 0.5);
998+
dict_set(layer, "rng_state", state);
999+
return layer;
1000+
}
1001+
1002+
fn prom_attention_substrate_k_mh_forward(layer, x_id) {
1003+
h n_heads = dict_get(layer, "n_heads");
1004+
h heads = dict_get(layer, "heads");
1005+
h smod_alpha = dict_get(layer, "smod_alpha");
1006+
h v_scale = dict_get(layer, "v_resample_scale");
1007+
h q6_mode = dict_get(layer, "q6_mode");
1008+
h q6_scale = dict_get(layer, "q6_scale");
1009+
h q6_gamma = dict_get(layer, "q6_gamma");
1010+
1011+
h sum_proj = null;
1012+
h hi = 0;
1013+
while hi < n_heads {
1014+
h head = arr_get(heads, hi);
1015+
h Q_w = dict_get(head, "Q");
1016+
h V_w = dict_get(head, "V");
1017+
h W_O = dict_get(head, "W_O");
1018+
h K_const = dict_get(head, "K_const");
1019+
1020+
h q = tape_matmul(x_id, Q_w); # [N, d_head]
1021+
h q_mod = prom_q6_modulate(q, q6_scale, q6_gamma, q6_mode);
1022+
h v_raw = tape_matmul(x_id, V_w); # [N, d_head]
1023+
h v = prom_substrate_resample(v_raw, v_scale);
1024+
h k = tape_const(K_const);
1025+
h kt = tape_transpose(k);
1026+
h scores = tape_matmul(q_mod, kt); # [N, seq_len]
1027+
h attn = prom_substrate_softmax(scores, smod_alpha);
1028+
h out_h = tape_matmul(attn, v); # [N, d_head]
1029+
h proj_h = tape_matmul(out_h, W_O); # [N, d_model]
1030+
if sum_proj == null {
1031+
sum_proj = proj_h;
1032+
} else {
1033+
sum_proj = tape_add(sum_proj, proj_h);
1034+
}
1035+
hi = hi + 1;
1036+
}
1037+
return sum_proj;
1038+
}
1039+
1040+
fn prom_attention_substrate_k_mh_params(layer) {
1041+
h heads = dict_get(layer, "heads");
1042+
h out = [];
1043+
h hi = 0;
1044+
while hi < arr_len(heads) {
1045+
h head = arr_get(heads, hi);
1046+
arr_push(out, dict_get(head, "Q"));
1047+
arr_push(out, dict_get(head, "V"));
1048+
arr_push(out, dict_get(head, "W_O"));
1049+
hi = hi + 1;
1050+
}
1051+
return out;
1052+
}
1053+
9391054
# Param collectors per variant.
9401055
fn prom_attention_substrate_k_params(layer) {
9411056
return [dict_get(layer, "Q"), dict_get(layer, "V")];
@@ -1141,55 +1256,23 @@ fn prom_embedding_params(layer) {
11411256
# Implemented via an [N, vocab] one-hot batch then matmul with the
11421257
# embedding table. Differentiable end-to-end.
11431258
fn prom_embedding_batch(layer, token_ids) {
1144-
h vocab = dict_get(layer, "vocab");
1259+
# v0.8.5 — defers to the tape_embedding_lookup Rust builtin. Direct
1260+
# row gather instead of building an [N, vocab] one-hot in OMC and
1261+
# matmulling. Backward scatters dL/dout rows back into table grad,
1262+
# which is the same gradient as the one-hot @ table chain produced.
11451263
h table = dict_get(layer, "table");
1146-
h n = arr_len(token_ids);
1147-
h onehot = [];
1148-
h i = 0;
1149-
while i < n {
1150-
h row = [];
1151-
h idx = arr_get(token_ids, i);
1152-
h j = 0;
1153-
while j < vocab {
1154-
if j == idx { arr_push(row, 1.0); }
1155-
else { arr_push(row, 0.0); }
1156-
j = j + 1;
1157-
}
1158-
arr_push(onehot, row);
1159-
i = i + 1;
1160-
}
1161-
h onehot_const = tape_const(onehot);
1162-
return tape_matmul(onehot_const, table);
1264+
return tape_embedding_lookup(table, token_ids);
11631265
}
11641266

11651267
# Batched cross-entropy: logits is [N, vocab], targets is array of N
11661268
# integer indices. Returns scalar mean loss (averaged over positions).
1269+
#
1270+
# v0.8.5 — defers to the fused tape_cross_entropy_batch Rust builtin
1271+
# (closed-form (p - one_hot) / N backward, no intermediate tape nodes).
1272+
# `vocab` is accepted but unused (the builtin reads it from logits.cols);
1273+
# kept in the signature for callers that pass it.
11671274
fn prom_cross_entropy_batch(logits_id, targets, vocab) {
1168-
h n = arr_len(targets);
1169-
h probs = tape_softmax(logits_id);
1170-
h log_probs = tape_log(probs);
1171-
# Build [N, vocab] mask: -1.0 at (i, targets[i]), 0 elsewhere.
1172-
h mask_rows = [];
1173-
h i = 0;
1174-
while i < n {
1175-
h row = [];
1176-
h tgt = arr_get(targets, i);
1177-
h c = 0;
1178-
while c < vocab {
1179-
if c == tgt { arr_push(row, -1.0); }
1180-
else { arr_push(row, 0.0); }
1181-
c = c + 1;
1182-
}
1183-
arr_push(mask_rows, row);
1184-
i = i + 1;
1185-
}
1186-
h mask = tape_const(mask_rows);
1187-
h selected = tape_mul(log_probs, mask);
1188-
# Mean over all cells = (sum of -log p_target) / (N * vocab).
1189-
# We want per-token mean = sum / N. Use sum + divide.
1190-
h s = tape_sum(selected);
1191-
h scale = tape_const(1.0 / n);
1192-
return tape_mul(s, scale);
1275+
return tape_cross_entropy_batch(logits_id, targets);
11931276
}
11941277

11951278
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)