@@ -344,3 +344,95 @@ def substrate_fft_loss(logits: torch.Tensor, targets: torch.Tensor,
344344 tgt_sin = target_onehot @ basis_sin
345345 fft_mismatch = ((pred_cos - tgt_cos ) ** 2 + (pred_sin - tgt_sin ) ** 2 ).mean ()
346346 return ce + lambda_substrate * fft_mismatch
347+
348+
349+ _PHI = (1.0 + 5.0 ** 0.5 ) / 2.0
350+ _PHI_PI = _PHI ** math .pi
351+ _LOG_PHI_PI = math .log (_PHI_PI )
352+
353+
354+ def substrate_omniweight_loss (logits : torch .Tensor , targets : torch .Tensor ,
355+ vocab_size : int ,
356+ lambda_substrate : float = 0.01 ,
357+ window : int = 21 ) -> torch .Tensor :
358+ """CE weighted by the substrate omniweight ledger evaluated on targets.
359+
360+ Ports the inference-side omniweight standard (fluid form
361+ phi^pi * tanh(delta / phi^pi)) to the training loss. Each target
362+ token's CE contribution is multiplied by exp(fluid_delta) where
363+ fluid_delta is the substrate's verdict on that token at its
364+ position. Tokens the inference ledger would suppress (stagnating
365+ repetitions) get their training gradient muted by the same standard
366+ -- closes the train/inference omniweight asymmetry.
367+
368+ Minimum-surface port: only the anti-stagnation primitive contributes
369+ to the ledger here (Fibonacci-tier counts F(6)=8, F(7)=13, F(8)=21
370+ over the preceding window, matching substrate_anti_stagnation).
371+ All deltas pass through the same phi^pi * tanh standard so
372+ additional primitives can be added without architectural change.
373+
374+ Weights are renormalized so mean weight = 1, preserving loss scale.
375+
376+ Args:
377+ logits: [B, T, V]
378+ targets: [B, T]
379+ vocab_size: V
380+ lambda_substrate: weight on the FFT-spectrum term (matches
381+ substrate_fft_loss; the CE term is the omniweight-modulated one)
382+ window: anti-stagnation window in tokens (default F(8)=21)
383+
384+ Returns:
385+ scalar loss
386+ """
387+ B , T = targets .shape
388+ V = vocab_size
389+ device = logits .device
390+ dtype = logits .dtype
391+
392+ # Per-position count of target[b,t] occurrences in targets[b, t-window:t].
393+ pos_idx = torch .arange (T , device = device )
394+ diff = pos_idx .unsqueeze (1 ) - pos_idx .unsqueeze (0 ) # [T, T]
395+ win_mask = ((diff > 0 ) & (diff <= window )).to (dtype ) # [T, T]
396+ eq = (targets .unsqueeze (2 ) == targets .unsqueeze (1 )).to (dtype ) # [B, T, T]
397+ counts = (eq * win_mask .unsqueeze (0 )).sum (dim = 2 ) # [B, T]
398+
399+ # Anti-stagnation contribution to the ledger (matches inference thresholds:
400+ # count >= F(6)=8 -> divide by phi^pi -> delta = -log(phi^pi)
401+ # count >= F(7)=13 -> divide by phi^(2pi) -> delta = -2*log(phi^pi)
402+ # count >= F(8)=21 -> hard suppression -> delta = -4*log(phi^pi)
403+ # (the inference path sets prob=0 at F(8); here we let tanh saturate.)
404+ delta = torch .zeros_like (counts )
405+ m_8 = (counts >= 8.0 ) & (counts < 13.0 )
406+ m_13 = (counts >= 13.0 ) & (counts < 21.0 )
407+ m_21 = counts >= 21.0
408+ delta = torch .where (m_8 , torch .full_like (delta , - _LOG_PHI_PI ), delta )
409+ delta = torch .where (m_13 , torch .full_like (delta , - 2.0 * _LOG_PHI_PI ), delta )
410+ delta = torch .where (m_21 , torch .full_like (delta , - 4.0 * _LOG_PHI_PI ), delta )
411+
412+ # Fluid substrate standard: phi^pi * tanh(delta / phi^pi). Same form
413+ # the inference omniweight uses (_omniweight_apply).
414+ fluid_delta = _PHI_PI * torch .tanh (delta / _PHI_PI )
415+ weight = torch .exp (fluid_delta ) # bounded in [exp(-phi^pi), 1]
416+
417+ # Per-token CE, weighted by the omniweight ledger.
418+ ce_per_tok = F .cross_entropy (
419+ logits .reshape (- 1 , V ),
420+ targets .reshape (- 1 ),
421+ reduction = 'none' ,
422+ ).reshape (B , T )
423+ ce = (ce_per_tok * weight ).sum () / (weight .sum () + 1e-8 )
424+
425+ # Same FFT-spectrum substrate term as substrate_fft_loss.
426+ fib_freqs = torch .tensor ([1 , 2 , 3 , 5 , 8 , 13 , 21 ], dtype = dtype , device = device )
427+ v_idx = torch .arange (vocab_size , dtype = dtype , device = device )
428+ angles = 2 * math .pi * v_idx .unsqueeze (1 ) * fib_freqs .unsqueeze (0 ) / vocab_size
429+ basis_cos = torch .cos (angles )
430+ basis_sin = torch .sin (angles )
431+ pred = F .softmax (logits , dim = - 1 )
432+ target_onehot = F .one_hot (targets , vocab_size ).to (pred .dtype )
433+ pred_cos = pred @ basis_cos
434+ pred_sin = pred @ basis_sin
435+ tgt_cos = target_onehot @ basis_cos
436+ tgt_sin = target_onehot @ basis_sin
437+ fft_mismatch = ((pred_cos - tgt_cos ) ** 2 + (pred_sin - tgt_sin ) ** 2 ).mean ()
438+ return ce + lambda_substrate * fft_mismatch
0 commit comments