Skip to content

Commit bfb220b

Browse files
Merge pull request #4 from RandomCoder-lab/claude/find-claude-md-arn0F
transformerless_lm: omniweight loss — standard on training data
2 parents 9ddc081 + 22f3fd1 commit bfb220b

3 files changed

Lines changed: 679 additions & 244 deletions

File tree

experiments/transformerless_lm/losses_substrate.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,95 @@ def substrate_fft_loss(logits: torch.Tensor, targets: torch.Tensor,
344344
tgt_sin = target_onehot @ basis_sin
345345
fft_mismatch = ((pred_cos - tgt_cos) ** 2 + (pred_sin - tgt_sin) ** 2).mean()
346346
return ce + lambda_substrate * fft_mismatch
347+
348+
349+
_PHI = (1.0 + 5.0 ** 0.5) / 2.0  # golden ratio
_PHI_PI = _PHI ** math.pi
_LOG_PHI_PI = math.log(_PHI_PI)

# Anti-stagnation tiers as (minimum repeat count, multiple of -log(phi^pi)).
# Listed in ascending order so that a higher tier overwrites a lower one.
_STAGNATION_TIERS = ((8, 1.0), (13, 2.0), (21, 4.0))


def _preceding_repeat_counts(targets: torch.Tensor, window: int) -> torch.Tensor:
    """Count how often targets[b, t] occurs in targets[b, t-window:t]."""
    seq_len = targets.shape[1]
    counts = torch.zeros_like(targets, dtype=torch.long)
    # One shifted comparison per lag keeps memory at O(B*T) instead of
    # materialising a dense [B, T, T] equality tensor.  Lags >= seq_len
    # cannot match anything, so they are skipped.
    for lag in range(1, min(window, seq_len - 1) + 1):
        counts[:, lag:] += (targets[:, lag:] == targets[:, :-lag]).long()
    return counts


def substrate_omniweight_loss(logits: torch.Tensor, targets: torch.Tensor,
                              vocab_size: int,
                              lambda_substrate: float = 0.01,
                              window: int = 21) -> torch.Tensor:
    """Cross-entropy weighted by an anti-stagnation "omniweight", plus an FFT term.

    Each target token's CE contribution is multiplied by
    ``exp(phi^pi * tanh(delta / phi^pi))`` where ``delta`` depends on how
    many times that token already occurred in the preceding ``window``
    positions of the same sequence:

    * count >= 8   -> delta = -1 * log(phi^pi)
    * count >= 13  -> delta = -2 * log(phi^pi)
    * count >= 21  -> delta = -4 * log(phi^pi)
    * otherwise    -> delta = 0 (weight 1)

    The weighted CE is divided by the sum of the weights, so the weights
    are effectively normalised and the loss scale is preserved.  The
    returned loss additionally contains the Fibonacci-frequency spectrum
    mismatch between the predicted distribution and the one-hot targets,
    scaled by ``lambda_substrate``.

    The tier thresholds are fixed; a tier can only fire if ``window`` is at
    least as large as its threshold.

    Args:
        logits: ``[B, T, V]`` unnormalised scores.
        targets: ``[B, T]`` integer class indices in ``[0, vocab_size)``.
            Negative "ignore" indices are not supported because the
            targets are one-hot encoded for the spectrum term.
        vocab_size: ``V``.
        lambda_substrate: weight of the FFT-spectrum mismatch term.
        window: number of preceding tokens inspected for repeats.

    Returns:
        Scalar loss tensor.
    """
    B, T = targets.shape
    V = vocab_size
    device = logits.device
    dtype = logits.dtype

    # Integer repeat counts: exact regardless of the logits' precision.
    counts = _preceding_repeat_counts(targets, window)

    # Ledger delta per position; later (higher) tiers override earlier ones.
    delta = torch.zeros(counts.shape, dtype=dtype, device=device)
    for min_count, multiple in _STAGNATION_TIERS:
        delta = delta.masked_fill(counts >= min_count, -multiple * _LOG_PHI_PI)

    # Soft-saturating form phi^pi * tanh(delta / phi^pi); since delta <= 0
    # the resulting weight lies in (exp(-phi^pi), 1].
    fluid_delta = _PHI_PI * torch.tanh(delta / _PHI_PI)
    weight = torch.exp(fluid_delta)

    # Per-token CE, averaged with the omniweight as importance weights.
    ce_per_tok = F.cross_entropy(
        logits.reshape(-1, V),
        targets.reshape(-1),
        reduction='none',
    ).reshape(B, T)
    ce = (ce_per_tok * weight).sum() / (weight.sum() + 1e-8)

    # FFT-spectrum term.  The basis angles are computed in at least float32:
    # in fp16/bf16 the vocabulary indices are not exactly representable and
    # 2*pi*v*21 can overflow to inf (-> NaN after cos/sin).
    if dtype in (torch.float16, torch.bfloat16):
        basis_dtype = torch.float32
    else:
        basis_dtype = dtype
    fib_freqs = torch.tensor([1, 2, 3, 5, 8, 13, 21],
                             dtype=basis_dtype, device=device)
    v_idx = torch.arange(vocab_size, dtype=basis_dtype, device=device)
    angles = 2 * math.pi * v_idx.unsqueeze(1) * fib_freqs.unsqueeze(0) / vocab_size

    pred = F.softmax(logits, dim=-1)
    basis_cos = torch.cos(angles).to(pred.dtype)
    basis_sin = torch.sin(angles).to(pred.dtype)
    target_onehot = F.one_hot(targets, vocab_size).to(pred.dtype)

    pred_cos = pred @ basis_cos
    pred_sin = pred @ basis_sin
    tgt_cos = target_onehot @ basis_cos
    tgt_sin = target_onehot @ basis_sin
    fft_mismatch = ((pred_cos - tgt_cos) ** 2 + (pred_sin - tgt_sin) ** 2).mean()
    return ce + lambda_substrate * fft_mismatch

0 commit comments

Comments
 (0)