@@ -79,6 +79,8 @@ class ChaosGrad(torch.optim.Optimizer):
     _FRUST_THRESH = 0.75
     _FRUST_NOISE = 0.01
     _FRUST_META_RESET = 0.30
+    _FRUST_IMPROVE_TOL = 1e-4
+    _FRUST_SCALE_FLOOR = 1e-10

     # Cold-start helpers
     _GENESIS_SCALAR = 1e-6
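
A brief worked example of the two new constants (the arithmetic mirrors the report_loss logic later in this diff; the loss values are made up):

# Illustration only: a 0.005% drop in loss is below _FRUST_IMPROVE_TOL (1e-4),
# so it does not register as a new best and the frustration signal is ~0.5.
best_loss, loss = 1.0, 0.99995
scale = max(abs(best_loss), abs(loss), 1e-10)        # _FRUST_SCALE_FLOOR
rel_improvement = (best_loss - loss) / scale          # 5e-5
new_best = rel_improvement > 1e-4                     # _FRUST_IMPROVE_TOL -> False
signal = 0.0 if new_best else 1.0 - min(1.0, max(0.0, rel_improvement / 1e-4))
print(new_best, round(signal, 3))                     # False 0.5
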
@@ -311,6 +313,150 @@ def _row_mode(self, p: torch.Tensor) -> bool:
311313 """True when this parameter uses per-row meta-params."""
312314 return self ._meta_resolution == 'row' and p .dim () >= 2
313315
316+ @staticmethod
317+ def _finite_scalar (x , fallback : float ) -> float :
318+ """Convert value to a finite float; otherwise return fallback."""
319+ try :
320+ if torch .is_tensor (x ):
321+ x = float (x .float ().mean ().item ())
322+ else :
323+ x = float (x )
324+ except Exception :
325+ return fallback
326+ if not math .isfinite (x ):
327+ return fallback
328+ return x
329+
330+ def _state_requires_reinit (self , state : dict , g_f : torch .Tensor ) -> bool :
331+ """Return True when state is missing/corrupt and must cold-start."""
332+ required = (
333+ 'step' , 'init_lr' , 'grad_ema' , 'momentum' ,
334+ 'per_param_lr' , 'per_param_beta' ,
335+ 'per_param_decay' , 'per_param_alpha' , 'v2' ,
336+ )
337+ if any (k not in state for k in required ):
338+ return True
339+
340+ init_lr = self ._finite_scalar (state .get ('init_lr' , 0.0 ), 0.0 )
341+ if init_lr <= 0.0 :
342+ return True
343+
344+ for key in ('grad_ema' , 'momentum' , 'v2' ):
345+ t = state .get (key )
346+ if not torch .is_tensor (t ) or t .shape != g_f .shape :
347+ return True
348+ if not bool (torch .isfinite (t ).all ().item ()):
349+ return True
350+
351+ return False
352+
353+ def _sanitize_meta_params (self , state : dict , p : torch .Tensor , group : dict ) -> None :
354+ """Keep adaptive meta-params finite, in-range, and shape-consistent."""
355+ row_mode = self ._row_mode (p )
356+ rows = p .shape [0 ] if p .dim () >= 2 else 1
357+
358+ is_hebbian = group .get ('is_hebbian' , False )
359+ beta_equil = self ._finite_scalar (group .get ('beta_equil' , 0.90 ), 0.90 )
360+ init_decay = 0.0 if is_hebbian else self ._finite_scalar (group .get ('init_decay' , 0.0 ), 0.0 )
361+
362+ init_lr = self ._finite_scalar (state .get ('init_lr' , self ._LR_MIN ), self ._LR_MIN )
363+ init_lr = max (self ._LR_MIN , min (self ._LR_MAX , init_lr ))
364+ state ['init_lr' ] = init_lr
365+
366+ def _sanitize_scalar (v , lo , hi , fallback ):
367+ x = self ._finite_scalar (v , fallback )
368+ if lo is not None :
369+ x = max (lo , x )
370+ if hi is not None :
371+ x = min (hi , x )
372+ return x
373+
374+ def _sanitize_row (v , lo , hi , fallback ):
375+ if torch .is_tensor (v ):
376+ t = v .to (device = p .device , dtype = torch .float32 )
377+ if t .shape != (rows ,):
378+ if t .numel () == 1 :
379+ t = t .reshape (1 ).expand (rows ).clone ()
380+ else :
381+ flat = t .reshape (- 1 )
382+ if flat .numel () >= rows :
383+ t = flat [:rows ].clone ()
384+ else :
385+ t = torch .full ((rows ,), fallback , dtype = torch .float32 , device = p .device )
386+ t [:flat .numel ()] = flat
387+ else :
388+ t = torch .full (
389+ (rows ,),
390+ self ._finite_scalar (v , fallback ),
391+ dtype = torch .float32 ,
392+ device = p .device ,
393+ )
394+
395+ pos_fill = hi if hi is not None else fallback
396+ neg_fill = lo if lo is not None else fallback
397+ t = torch .nan_to_num (t , nan = fallback , posinf = pos_fill , neginf = neg_fill )
398+ if lo is not None or hi is not None :
399+ t .clamp_ (
400+ lo if lo is not None else - float ('inf' ),
401+ hi if hi is not None else float ('inf' ),
402+ )
403+ return t
404+
405+ if row_mode :
406+ state ['per_param_lr' ] = _sanitize_row (
407+ state .get ('per_param_lr' , init_lr ),
408+ self ._LR_MIN ,
409+ self ._LR_MAX ,
410+ init_lr ,
411+ )
412+ state ['per_param_beta' ] = _sanitize_row (
413+ state .get ('per_param_beta' , beta_equil ),
414+ self ._BETA_MIN ,
415+ self ._BETA_MAX ,
416+ beta_equil ,
417+ )
418+ if is_hebbian :
419+ state ['per_param_decay' ] = torch .zeros ((rows ,), dtype = torch .float32 , device = p .device )
420+ else :
421+ state ['per_param_decay' ] = _sanitize_row (
422+ state .get ('per_param_decay' , init_decay ),
423+ 0.0 ,
424+ self ._DECAY_MAX ,
425+ init_decay ,
426+ )
427+ state ['per_param_alpha' ] = _sanitize_row (
428+ state .get ('per_param_alpha' , 0.5 ),
429+ 0.0 ,
430+ 1.0 ,
431+ 0.5 ,
432+ )
433+ return
434+
435+ state ['per_param_lr' ] = _sanitize_scalar (
436+ state .get ('per_param_lr' , init_lr ),
437+ self ._LR_MIN ,
438+ self ._LR_MAX ,
439+ init_lr ,
440+ )
441+ state ['per_param_beta' ] = _sanitize_scalar (
442+ state .get ('per_param_beta' , beta_equil ),
443+ self ._BETA_MIN ,
444+ self ._BETA_MAX ,
445+ beta_equil ,
446+ )
447+ state ['per_param_decay' ] = 0.0 if is_hebbian else _sanitize_scalar (
448+ state .get ('per_param_decay' , init_decay ),
449+ 0.0 ,
450+ self ._DECAY_MAX ,
451+ init_decay ,
452+ )
453+ state ['per_param_alpha' ] = _sanitize_scalar (
454+ state .get ('per_param_alpha' , 0.5 ),
455+ 0.0 ,
456+ 1.0 ,
457+ 0.5 ,
458+ )
459+
314460 def _init_param_state (
315461 self ,
316462 p : torch .Tensor ,
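
The nested `_sanitize_row` helper above repairs wrong-shaped and non-finite per-row tensors before clamping them. A minimal standalone sketch of that repair pattern, using illustrative bounds rather than the class's actual `_LR_MIN`/`_LR_MAX`:

import torch

def sanitize_row(v: torch.Tensor, rows: int, lo: float, hi: float, fallback: float) -> torch.Tensor:
    # Coerce to float32, fix the length, replace NaN/Inf, then clamp to range.
    t = v.to(dtype=torch.float32).reshape(-1)
    if t.numel() >= rows:
        t = t[:rows].clone()
    else:
        padded = torch.full((rows,), fallback, dtype=torch.float32)
        padded[: t.numel()] = t
        t = padded
    t = torch.nan_to_num(t, nan=fallback, posinf=hi, neginf=lo)
    return t.clamp_(lo, hi)

bad = torch.tensor([float('nan'), 5.0, -1.0])             # wrong values, wrong length
print(sanitize_row(bad, rows=4, lo=1e-6, hi=1.0, fallback=1e-3))
# -> [1e-3, 1.0, 1e-6, 1e-3]
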
@@ -554,6 +700,11 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()

+        if not math.isfinite(self._frustration):
+            self._frustration = 0.0
+        if not math.isfinite(self._best_loss):
+            self._best_loss = float('inf')
+
         burst_now = (self._frustration > self._FRUST_THRESH) or self._force_plateau_escape
         self._force_plateau_escape = False

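For intuition about the burst trigger above: given the EMA update used in report_loss and _FRUST_THRESH = 0.75, the sketch below counts how many consecutive pure-plateau steps (signal == 1.0) are needed before burst_now fires. The decay of 0.9 is an assumed stand-in for _FRUST_DECAY, which is defined elsewhere in the class:

def steps_to_burst(decay: float, thresh: float = 0.75) -> int:
    # Iterate the frustration EMA with a constant stagnation signal of 1.0.
    frustration, steps = 0.0, 0
    while frustration <= thresh:
        frustration = frustration * decay + 1.0 * (1.0 - decay)
        steps += 1
    return steps

print(steps_to_burst(0.9))   # 14 plateau steps at an assumed decay of 0.9
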
@@ -577,33 +728,42 @@ def step(self, closure=None):
                 )

                 g_f = p.grad.float().detach()
+                if not bool(torch.isfinite(g_f).all().item()):
+                    # Do not let NaN/Inf gradients poison state/weights.
+                    self.reset_param_state(p)
+                    continue
+
+                # Respect hard architectural constraints in all optimizer
+                # pathways (signals, moments, preconditioner).
+                g_signal = g_f
+                if is_core and g_signal.dim() == 2 and g_signal.shape[0] == g_signal.shape[1]:
+                    g_signal = g_signal.clone()
+                    g_signal.fill_diagonal_(0.0)

                 # ---- Cold start ----
                 if not self.state[p]:
-                    self.state[p] = self._init_param_state(p, g_f, group)
+                    self.state[p] = self._init_param_state(p, g_signal, group)

                 state = self.state[p]

-                # Neurogenesis/legacy checkpoints can occasionally produce
-                # partial state payloads. Re-seed missing essentials so step
-                # logic remains total and shape-safe.
-                required = (
-                    'step', 'init_lr', 'grad_ema', 'momentum',
-                    'per_param_lr', 'per_param_beta',
-                    'per_param_decay', 'per_param_alpha', 'v2',
-                )
-                if any(k not in state for k in required):
-                    self.state[p] = self._init_param_state(p, g_f, group)
+                # Re-seed invalid state payloads so the update remains stable.
+                if self._state_requires_reinit(state, g_signal):
+                    self.state[p] = self._init_param_state(p, g_signal, group)
                     state = self.state[p]

+                # Sanitize before meta-update math so corrupt payloads cannot
+                # explode or shape-break the update equations.
+                self._sanitize_meta_params(state, p, group)
+
                 state['step'] += 1
                 step = state['step']

                 # ---- Hypergradient signals ----
-                sigs = self._compute_signals(g_f, state, p)
+                sigs = self._compute_signals(g_signal, state, p)

                 # ---- Meta-parameter update ----
                 self._update_meta_params(state, sigs, group)
+                self._sanitize_meta_params(state, p, group)

                 per_lr = state['per_param_lr']
                 per_beta = state['per_param_beta']
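
The chaos-core constraint introduced above is the standard clone-then-fill_diagonal_ pattern, so p.grad itself is never modified. A minimal sketch:

import torch

g = torch.randn(4, 4)            # stand-in for a square chaos-core gradient
g_signal = g.clone()
g_signal.fill_diagonal_(0.0)     # no self-connection updates leak through
assert torch.all(g_signal.diagonal() == 0)
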
@@ -614,7 +774,7 @@ def step(self, closure=None):

                 # In row mode per_lr / per_beta are (rows,) tensors. Reshape
                 # them for broadcasting across the full parameter shape.
-                row_mode = torch.is_tensor(per_lr)
+                row_mode = self._row_mode(p)
                 if row_mode:
                     view_shape = (-1,) + (1,) * (p.dim() - 1)
                     per_lr_b = per_lr.view(view_shape)
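
A small illustration of the row-mode broadcast used here: a (rows,) vector of per-row learning rates is viewed as (-1, 1, ...) so it scales every element of its row:

import torch

p = torch.zeros(3, 5)                            # stand-in parameter
per_lr = torch.tensor([0.1, 0.2, 0.3])           # one learning rate per row
view_shape = (-1,) + (1,) * (p.dim() - 1)        # (-1, 1)
per_lr_b = per_lr.view(view_shape)               # shape (3, 1), broadcasts to (3, 5)
print((torch.ones_like(p) * per_lr_b)[2])        # every column of row 2 is 0.3
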
@@ -637,19 +797,19 @@ def step(self, closure=None):

                 # ---- Update grad EMA (for next step's signals) ----
                 state['grad_ema'].mul_(self._SIGNAL_ALPHA).add_(
-                    g_f, alpha=1.0 - self._SIGNAL_ALPHA
+                    g_signal, alpha=1.0 - self._SIGNAL_ALPHA
                 )

                 # ---- Gradient centralization ----
-                g_proc = g_f
-                if g_f.dim() >= 2:
-                    dims = tuple(range(1, g_f.dim()))
-                    g_mean = g_f.mean(dim=dims, keepdim=True)
+                g_proc = g_signal
+                if g_signal.dim() >= 2:
+                    dims = tuple(range(1, g_signal.dim()))
+                    g_mean = g_signal.mean(dim=dims, keepdim=True)
                     if row_mode and torch.is_tensor(per_alpha):
                         if torch.any(per_alpha > 1e-3):
-                            g_proc = g_f - per_alpha_b * g_mean
+                            g_proc = g_signal - per_alpha_b * g_mean
                     elif per_alpha > 1e-3:
-                        g_proc = g_f - per_alpha * g_mean
+                        g_proc = g_signal - per_alpha * g_mean

                 # ---- Zero diagonal on chaos core gradient ----
                 if is_core and g_proc.dim() == 2 and g_proc.shape[0] == g_proc.shape[1]:
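
Gradient centralization as applied above subtracts each row's mean, scaled by the per-parameter alpha (0 disables it, 1 fully centers each row). A tiny worked example with illustrative numbers:

import torch

g = torch.tensor([[1.0, 3.0], [2.0, 6.0]])
alpha = 0.5
dims = tuple(range(1, g.dim()))
g_mean = g.mean(dim=dims, keepdim=True)          # [[2.0], [4.0]]
g_proc = g - alpha * g_mean
print(g_proc)                                     # [[0., 2.], [0., 4.]]
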
@@ -664,6 +824,10 @@ def step(self, closure=None):
                 else:
                     v.mul_(per_beta).add_(g_proc, alpha=1.0 - per_beta)

+                # Keep optimizer state in the same constrained subspace as W.
+                if is_core and v.dim() == 2 and v.shape[0] == v.shape[1]:
+                    v.fill_diagonal_(0.0)
+
                 # ---- Frustration burst ----
                 # Noise is scaled to the current momentum RMS so that the
                 # perturbation is always a comparable fraction of "how fast we
@@ -686,6 +850,9 @@ def step(self, closure=None):
                         noise = noise - per_alpha * noise_mean
                     v.add_(noise)

+                    if is_core and v.dim() == 2 and v.shape[0] == v.shape[1]:
+                        v.fill_diagonal_(0.0)
+
                     if burst_type == 'full':
                         beta_equil = group.get('beta_equil', 0.90)
                         if row_mode:
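
A hedged sketch of the RMS-scaled burst noise the comment above describes: the perturbation tracks the momentum buffer's own RMS, so _FRUST_NOISE injects a fixed fraction of the current update speed regardless of parameter scale. The exact wiring inside step() (burst gating, row handling) is simplified here:

import torch

v = torch.randn(3, 4) * 5.0                       # stand-in momentum buffer
rms = v.pow(2).mean().sqrt().clamp_min(1e-12)     # current "speed" of the update
noise = torch.randn_like(v) * rms * 0.01          # _FRUST_NOISE-sized kick
v.add_(noise)
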
@@ -715,7 +882,11 @@ def step(self, closure=None):

                 # ---- Elementwise second-moment (Adam-style) ----
                 v2 = state['v2']
-                v2.mul_(self._BETA2).addcmul_(g_f, g_f, value=1.0 - self._BETA2)
+                # Keep denominator dynamics consistent with the transformed
+                # gradient that drives momentum (centralization included).
+                v2.mul_(self._BETA2).addcmul_(g_proc, g_proc, value=1.0 - self._BETA2)
+                if is_core and v2.dim() == 2 and v2.shape[0] == v2.shape[1]:
+                    v2.fill_diagonal_(0.0)
                 bias_corr_v2 = max(1.0 - self._BETA2 ** step, self._EPS)
                 denom = (v2 / bias_corr_v2).sqrt().add_(self._EPS)

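The preconditioner above is the usual Adam-style second moment, now fed by g_proc. A self-contained numeric sketch with assumed beta2 = 0.999 and eps = 1e-8 (the real values live in the class constants _BETA2 and _EPS):

import torch

beta2, eps, step = 0.999, 1e-8, 1
g_proc = torch.tensor([0.5, -2.0])
v2 = torch.zeros_like(g_proc)
v2.mul_(beta2).addcmul_(g_proc, g_proc, value=1.0 - beta2)
bias_corr = max(1.0 - beta2 ** step, eps)
denom = (v2 / bias_corr).sqrt().add_(eps)
print(denom)    # ~[0.5, 2.0]: bias correction undoes the (1 - beta2) factor at step 1
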
@@ -761,18 +932,48 @@ def report_loss(self, loss_value: float) -> None:

         Call this once per optimizer step (the trainer does this automatically
         when ChaosGrad is detected as the active optimizer).
+
+        The frustration signal is sign-agnostic and scale-robust:
+        - meaningful relative improvement -> 0.0 signal
+        - exact plateau / regression -> 1.0 signal
+        - sub-threshold improvement -> linearly reduced signal
         """
         loss = float(loss_value)
-        if loss < self._best_loss * 0.9999:
-            self._best_loss = loss
-            frustration_signal = 0.0
+
+        if not math.isfinite(self._frustration):
+            self._frustration = 0.0
+
+        if not math.isfinite(loss):
+            # Treat NaN/Inf losses as a maximal stagnation signal without
+            # corrupting best_loss / frustration with NaN values.
+            self._frustration = (
+                self._frustration * self._FRUST_DECAY
+                + 1.0 * (1.0 - self._FRUST_DECAY)
+            )
+            self._frustration = max(0.0, min(1.0, self._frustration))
+            return
+
+        if not math.isfinite(self._best_loss):
+            self._best_loss = loss
+            frustration_signal = 0.0
         else:
-            frustration_signal = min(1.0, loss / max(self._best_loss, 1e-10))
+            scale = max(abs(self._best_loss), abs(loss), self._FRUST_SCALE_FLOOR)
+            rel_improvement = (self._best_loss - loss) / scale
+
+            if rel_improvement > self._FRUST_IMPROVE_TOL:
+                self._best_loss = loss
+                frustration_signal = 0.0
+            else:
+                # 1.0 for plateau/regression, smoothly reduced by small gains
+                # that do not pass the "new best" threshold.
+                sub_thr_gain = max(0.0, min(1.0, rel_improvement / self._FRUST_IMPROVE_TOL))
+                frustration_signal = 1.0 - sub_thr_gain

         self._frustration = (
             self._frustration * self._FRUST_DECAY
             + frustration_signal * (1.0 - self._FRUST_DECAY)
         )
+        self._frustration = max(0.0, min(1.0, self._frustration))

     def trigger_plateau_escape(self) -> None:
         """Manually trigger a frustration burst on the next :meth:`step` call."""
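
Finally, a hedged usage sketch of the loss-reporting hook: the trainer is expected to call report_loss once per step so the frustration EMA sees plateaus, and trigger_plateau_escape forces a burst on the next step. The constructor call and toy model below are placeholder assumptions; the actual ChaosGrad signature may require extra group options:

import torch

model = torch.nn.Linear(4, 1)
opt = ChaosGrad(model.parameters())               # assumed default hyper-params
x, y = torch.randn(32, 4), torch.randn(32, 1)

for _ in range(10):
    loss = torch.nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    opt.report_loss(loss.item())                  # drives the frustration EMA

opt.trigger_plateau_escape()                      # force a burst on the next step()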