Skip to content

Commit 8d72769

Browse files
committed
transformerless_lm: split-brain omniweight (math + lang hemispheres)
Two separate omniweight registers. Math hemisphere (frequency/decay): substrate-sampling, recency, bigram, anti-stagnation, bigram-saturation. Language hemisphere (purpose/structure): iambic, anaphora, need-fill, phonotactics, rhyme, agreement, word-spacing, char-cascade, pronounceability, subject-threading, theme-momentum. Each hemisphere builds its own fluid delta via the tanh-scaled substrate reserve phi^pi. The final distribution is the geometric mean of the two: sqrt(p_math * p_lang) / Z. A token survives only if both hemispheres consent (Bayesian Product of Experts). This is the user-named "left/right brain" architecture: math is the older substrate foundation; language is the newer purpose layer. The geometric mean is the substrate-canonical consensus mixer.
1 parent b107bb8 commit 8d72769

1 file changed

Lines changed: 96 additions & 56 deletions

File tree

experiments/transformerless_lm/train_self_recursive.py

Lines changed: 96 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,19 +1318,41 @@ def _omniweight_apply(base_probs: torch.Tensor,
13181318
"""Apply accumulated log-pressure via tanh-scaled substrate reserve.
13191319
13201320
fluid_delta = phi^pi * tanh(delta_acc / phi^pi)
1321-
1322-
Small contributions pass linear (tanh near origin ~ identity).
1323-
Large contributions saturate gracefully toward +/- phi^pi.
1324-
When primitives agree, deltas sum cleanly. When they disagree,
1325-
they cancel naturally within the sum.
1326-
1327-
Pure substrate (phi^pi as the reserve standard).
13281321
"""
13291322
fluid = _OMNIWEIGHT_RESERVE * torch.tanh(delta_acc / _OMNIWEIGHT_RESERVE)
13301323
out = base_probs * torch.exp(fluid)
13311324
return out / (out.sum() + 1e-8)
13321325

13331326

1327+
def _omniweight_apply_split(base_probs: torch.Tensor,
                            math_delta: torch.Tensor,
                            lang_delta: torch.Tensor) -> torch.Tensor:
    """SPLIT-BRAIN omniweight: two registers, geometric-mean mixer.

    Math hemisphere: bigram, recency, substrate sampling, anti-stag,
    bigram-saturation. Frequency / decay primitives.

    Language hemisphere: iambic, anaphora, need-fill, phonotactics,
    rhyme, agreement, word-spacing, char-cascade, pronunciation,
    subject-threading, theme. Purpose / structure primitives.

    Each hemisphere turns its accumulated log-pressure into its own
    distribution via ``_omniweight_apply`` (tanh-scaled substrate
    reserve, then renormalise). The final distribution is the
    renormalised geometric mean of the two, so a token keeps weight
    only to the extent that both hemispheres support it.

    Args:
        base_probs: Base distribution both hemispheres reweight. The
            callers in this file pass the softmax of one position's
            logits.
        math_delta: Summed log-pressure from the math-hemisphere
            primitives; same shape as ``base_probs``.
        lang_delta: Summed log-pressure from the language-hemisphere
            primitives; same shape as ``base_probs``.

    Returns:
        Tensor shaped like ``base_probs``, divided by its total sum
        (plus 1e-8 to guard against division by zero).
    """
    # Reuse the single-register helper for each hemisphere so the
    # reserve / tanh / normalisation rule lives in one place.
    p_math = _omniweight_apply(base_probs, math_delta)
    p_lang = _omniweight_apply(base_probs, lang_delta)
    # Geometric mean. Take the roots BEFORE multiplying: the raw product
    # p_math * p_lang can underflow to exactly 0 in low-precision float
    # dtypes for tokens that are unlikely in both hemispheres, even
    # though sqrt(p_math * p_lang) itself is representable.
    p_final = torch.sqrt(p_math) * torch.sqrt(p_lang)
    return p_final / (p_final.sum() + 1e-8)
1354+
1355+
13341356
def autoregressive_generate(model, prompt: torch.Tensor, n_new: int,
13351357
vocab_size: int, temperature: float = 1.0,
13361358
substrate_sampling: bool = True,
@@ -1409,70 +1431,77 @@ def autoregressive_generate(model, prompt: torch.Tensor, n_new: int,
14091431
T = seq.shape[1]
14101432
ctx = seq if T <= model.seq_len else seq[:, -model.seq_len:]
14111433
logits = model(ctx)[:, -1, :] / temperature
1434+
# SPLIT-BRAIN: base = softmax(plain logits); recency &
1435+
# substrate-sampling become math omniweight contributors.
1436+
base = F.softmax(logits[0], dim=-1)
1437+
math_delta = torch.zeros_like(base)
1438+
lang_delta = torch.zeros_like(base)
1439+
# ---- Math hemisphere ----
14121440
if recency_penalty:
14131441
history_t = seq[0, -recency_window:]
1414-
logits[0] = substrate_recency_penalty(
1442+
rec_logits = substrate_recency_penalty(
14151443
history_t, logits[0], vocab_size)
1444+
p = F.softmax(rec_logits, dim=-1)
1445+
math_delta += _omniweight_delta(base, p)
14161446
if substrate_sampling:
1417-
probs = F.softmax(logits * _PI_LOG_PHI, dim=-1)
1418-
else:
1419-
probs = F.softmax(logits, dim=-1)
1420-
# OMNIWEIGHT: every primitive contributes delta_log_p to a
1421-
# shared accumulator. Total clamped, applied once.
1422-
base = probs[0]
1423-
delta_acc = torch.zeros_like(base)
1447+
p = F.softmax(logits[0] * _PI_LOG_PHI, dim=-1)
1448+
math_delta += _omniweight_delta(base, p)
14241449
if bigram_prior is not None and seq.shape[1] >= 1:
14251450
ctx_back = seq[0, -7:].tolist()
14261451
p = substrate_syntax_blend(
14271452
int(seq[0, -1]), bigram_prior, base,
14281453
context_tokens=ctx_back, vocab=vocab)
1429-
delta_acc += _omniweight_delta(base, p)
1454+
math_delta += _omniweight_delta(base, p)
1455+
if seq.shape[1] >= 1:
1456+
p = substrate_bigram_saturation(
1457+
int(seq[0, -1]), recent_pairs, base)
1458+
math_delta += _omniweight_delta(base, p)
1459+
history_aw = seq[0, -21:]
1460+
p = substrate_anti_stagnation(history_aw, base, vocab_size)
1461+
math_delta += _omniweight_delta(base, p)
1462+
# ---- Language hemisphere ----
14301463
p = substrate_iambic_phase(
14311464
syl_pos, base, vocab_size, newline_mask=newline_mask)
1432-
delta_acc += _omniweight_delta(base, p)
1465+
lang_delta += _omniweight_delta(base, p)
14331466
if pronoun_mask is not None and seq.shape[1] >= 1:
14341467
recent_list = seq[0, -13:].tolist()
14351468
p = substrate_reference_chain(
14361469
recent_list, pronoun_mask, base)
1437-
delta_acc += _omniweight_delta(base, p)
1470+
lang_delta += _omniweight_delta(base, p)
14381471
if open_needs > 0:
14391472
p = substrate_need_fill(
14401473
open_needs, base, vocab_size, punct_mask=punct_mask)
1441-
delta_acc += _omniweight_delta(base, p)
1474+
lang_delta += _omniweight_delta(base, p)
14421475
if vowel_start_mask is not None and cluster_len >= 2:
14431476
p = substrate_phonotactics(
14441477
cluster_len, base, vowel_start_mask)
1445-
delta_acc += _omniweight_delta(base, p)
1478+
lang_delta += _omniweight_delta(base, p)
14461479
if end_vowels is not None and seq.shape[1] >= 1:
14471480
recent_list = seq[0, -13:].tolist()
14481481
p = substrate_rhyme_resonance(
14491482
recent_list, end_vowels, base)
1450-
delta_acc += _omniweight_delta(base, p)
1451-
if seq.shape[1] >= 1:
1452-
p = substrate_bigram_saturation(
1453-
int(seq[0, -1]), recent_pairs, base)
1454-
delta_acc += _omniweight_delta(base, p)
1483+
lang_delta += _omniweight_delta(base, p)
14551484
if vocab is not None:
14561485
p = substrate_agreement(
14571486
last_content_ends_s, base, vocab)
1458-
delta_acc += _omniweight_delta(base, p)
1487+
lang_delta += _omniweight_delta(base, p)
14591488
if vocab is not None and seq.shape[1] >= 1:
14601489
p = substrate_word_spacing(
14611490
int(seq[0, -1]), base, vocab, n_chars=n_chars_local)
1462-
delta_acc += _omniweight_delta(base, p)
1491+
lang_delta += _omniweight_delta(base, p)
14631492
if char_run >= _FIB_NUMS_FOR_BIGRAM[3]:
14641493
p = substrate_char_cascade(
14651494
char_run, base, n_chars_local)
1466-
delta_acc += _omniweight_delta(base, p)
1495+
lang_delta += _omniweight_delta(base, p)
14671496
if unpronounceable_mask is not None:
14681497
p = substrate_pronounceability(
14691498
base, unpronounceable_mask)
1470-
delta_acc += _omniweight_delta(base, p)
1499+
lang_delta += _omniweight_delta(base, p)
14711500
if token_signatures is not None and seq.shape[1] >= 1:
14721501
recent_list = seq[0, -13:].tolist()
14731502
p = substrate_theme_momentum(
14741503
recent_list, token_signatures, base)
1475-
delta_acc += _omniweight_delta(base, p)
1504+
lang_delta += _omniweight_delta(base, p)
14761505
if vocab is not None and seq.shape[1] >= 1:
14771506
prev_tok_id = int(seq[0, -1])
14781507
prev_str = (vocab[prev_tok_id]
@@ -1481,12 +1510,10 @@ def autoregressive_generate(model, prompt: torch.Tensor, n_new: int,
14811510
seq_list = seq[0].tolist()
14821511
p = substrate_subject_threading(
14831512
seq_list, vocab, base, is_sentence_start=True)
1484-
delta_acc += _omniweight_delta(base, p)
1485-
history_aw = seq[0, -21:]
1486-
p = substrate_anti_stagnation(history_aw, base, vocab_size)
1487-
delta_acc += _omniweight_delta(base, p)
1488-
# Apply accumulated omniweight pressure (clamped).
1489-
probs[0] = _omniweight_apply(base, delta_acc)
1513+
lang_delta += _omniweight_delta(base, p)
1514+
# Apply split-brain mixer (geometric mean).
1515+
probs = _omniweight_apply_split(
1516+
base, math_delta, lang_delta).unsqueeze(0)
14901517
# Vocab curriculum (HARD mask, post-omniweight).
14911518
if active_vocab_size is not None:
14921519
probs[0] = substrate_vocab_curriculum(
@@ -1591,18 +1618,27 @@ def _single_stage_refine(model, draft, vocab_size, scorer, mode: str,
15911618
if t_draft < new.shape[1] and t_draft >= prompt_len:
15921619
start = max(0, t_draft - recency_window)
15931620
history_t = new[0, start:t_draft]
1594-
pos_logits = substrate_recency_penalty(
1621+
base_probs = F.softmax(logits[0, idx] / temperature, dim=-1)
1622+
# SPLIT-BRAIN: math + lang accumulators.
1623+
math_delta = torch.zeros_like(base_probs)
1624+
lang_delta = torch.zeros_like(base_probs)
1625+
# ---- Math hemisphere ----
1626+
# Recency penalty.
1627+
rec_logits = substrate_recency_penalty(
15951628
history_t, logits[0, idx], vocab_size_local)
1596-
base_probs = F.softmax(pos_logits / temperature, dim=-1)
1597-
# OMNIWEIGHT accumulator.
1598-
delta_acc = torch.zeros_like(base_probs)
1629+
p = F.softmax(rec_logits / temperature, dim=-1)
1630+
math_delta += _omniweight_delta(base_probs, p)
1631+
# Substrate sampling (phi^pi sharpening).
1632+
p = F.softmax(logits[0, idx] * _PI_LOG_PHI, dim=-1)
1633+
math_delta += _omniweight_delta(base_probs, p)
15991634
if bigram_prior is not None and t_draft >= 1:
16001635
ctx_back_start = max(0, t_draft - 7)
16011636
ctx_back = new[0, ctx_back_start:t_draft].tolist()
16021637
p = substrate_syntax_blend(
16031638
int(new[0, t_draft - 1]), bigram_prior, base_probs,
16041639
context_tokens=ctx_back, vocab=vocab)
1605-
delta_acc += _omniweight_delta(base_probs, p)
1640+
math_delta += _omniweight_delta(base_probs, p)
1641+
# ---- Language hemisphere ----
16061642
if vocab is not None:
16071643
syl_pos = 0
16081644
for tid in new[0, :t_draft].tolist():
@@ -1611,13 +1647,13 @@ def _single_stage_refine(model, draft, vocab_size, scorer, mode: str,
16111647
p = substrate_iambic_phase(
16121648
syl_pos, base_probs, vocab_size_local,
16131649
newline_mask=newline_mask)
1614-
delta_acc += _omniweight_delta(base_probs, p)
1650+
lang_delta += _omniweight_delta(base_probs, p)
16151651
if pronoun_mask is not None and t_draft >= 1:
16161652
recent_start = max(0, t_draft - 13)
16171653
recent_list = new[0, recent_start:t_draft].tolist()
16181654
p = substrate_reference_chain(
16191655
recent_list, pronoun_mask, base_probs)
1620-
delta_acc += _omniweight_delta(base_probs, p)
1656+
lang_delta += _omniweight_delta(base_probs, p)
16211657
# State-dependent primitives: compute from prefix.
16221658
n_chars_r = sum(1 for t in vocab if len(t) == 1) if vocab else 65
16231659
ct = n_chars_r + _FIB_NUMS_FOR_BIGRAM[7]
@@ -1659,45 +1695,48 @@ def _single_stage_refine(model, draft, vocab_size, scorer, mode: str,
16591695
if j > 0:
16601696
rp.append((int(new[0, j-1].item()), tid))
16611697
rp = rp[-13:]
1698+
# Language hemisphere primitives.
16621699
if op_needs > 0:
16631700
p = substrate_need_fill(
16641701
op_needs, base_probs, vocab_size_local,
16651702
punct_mask=punct_mask)
1666-
delta_acc += _omniweight_delta(base_probs, p)
1703+
lang_delta += _omniweight_delta(base_probs, p)
16671704
if vowel_start_mask is not None and cl_len >= 2:
16681705
p = substrate_phonotactics(
16691706
cl_len, base_probs, vowel_start_mask)
1670-
delta_acc += _omniweight_delta(base_probs, p)
1707+
lang_delta += _omniweight_delta(base_probs, p)
1708+
# Math hemisphere primitives.
16711709
p = substrate_bigram_saturation(
16721710
int(new[0, t_draft - 1]), rp, base_probs)
1673-
delta_acc += _omniweight_delta(base_probs, p)
1711+
math_delta += _omniweight_delta(base_probs, p)
1712+
# Language hemisphere.
16741713
p = substrate_agreement(
16751714
last_s_r, base_probs, vocab)
1676-
delta_acc += _omniweight_delta(base_probs, p)
1715+
lang_delta += _omniweight_delta(base_probs, p)
16771716
p = substrate_word_spacing(
16781717
int(new[0, t_draft - 1]), base_probs, vocab,
16791718
n_chars=n_chars_r)
1680-
delta_acc += _omniweight_delta(base_probs, p)
1719+
lang_delta += _omniweight_delta(base_probs, p)
16811720
if char_run_r >= _FIB_NUMS_FOR_BIGRAM[3]:
16821721
p = substrate_char_cascade(
16831722
char_run_r, base_probs, n_chars_r)
1684-
delta_acc += _omniweight_delta(base_probs, p)
1723+
lang_delta += _omniweight_delta(base_probs, p)
16851724
if unpronounceable_mask is not None:
16861725
p = substrate_pronounceability(
16871726
base_probs, unpronounceable_mask)
1688-
delta_acc += _omniweight_delta(base_probs, p)
1727+
lang_delta += _omniweight_delta(base_probs, p)
16891728
if end_vowels is not None:
16901729
recent_start_ev = max(0, t_draft - 13)
16911730
recent_list_ev = new[0, recent_start_ev:t_draft].tolist()
16921731
p = substrate_rhyme_resonance(
16931732
recent_list_ev, end_vowels, base_probs)
1694-
delta_acc += _omniweight_delta(base_probs, p)
1733+
lang_delta += _omniweight_delta(base_probs, p)
16951734
if token_signatures is not None and t_draft >= 1:
16961735
recent_start = max(0, t_draft - 13)
16971736
recent_list = new[0, recent_start:t_draft].tolist()
16981737
p = substrate_theme_momentum(
16991738
recent_list, token_signatures, base_probs)
1700-
delta_acc += _omniweight_delta(base_probs, p)
1739+
lang_delta += _omniweight_delta(base_probs, p)
17011740
if vocab is not None and t_draft >= 1:
17021741
prev_tok_id = int(new[0, t_draft - 1])
17031742
prev_str = (vocab[prev_tok_id]
@@ -1707,14 +1746,15 @@ def _single_stage_refine(model, draft, vocab_size, scorer, mode: str,
17071746
p = substrate_subject_threading(
17081747
seq_list, vocab, base_probs,
17091748
is_sentence_start=True)
1710-
delta_acc += _omniweight_delta(base_probs, p)
1749+
lang_delta += _omniweight_delta(base_probs, p)
17111750
aw_start = max(0, t_draft - 21)
17121751
history_aw = new[0, aw_start:t_draft]
17131752
p = substrate_anti_stagnation(
17141753
history_aw, base_probs, vocab_size_local)
1715-
delta_acc += _omniweight_delta(base_probs, p)
1716-
# Apply omniweight pressure.
1717-
pos_probs = _omniweight_apply(base_probs, delta_acc)
1754+
math_delta += _omniweight_delta(base_probs, p)
1755+
# Apply split-brain mixer (geometric mean).
1756+
pos_probs = _omniweight_apply_split(
1757+
base_probs, math_delta, lang_delta)
17181758
# Vocab curriculum (HARD mask, post-omniweight).
17191759
if active_vocab_size is not None:
17201760
pos_probs = substrate_vocab_curriculum(

0 commit comments

Comments
 (0)