Commit daa75fa

Make ChaosGrad more robust

1 parent 6def69c commit daa75fa

2 files changed: 391 additions & 25 deletions

File tree:
  odyssnet/training/chaos_optimizer.py (226 additions & 25 deletions)
@@ -79,6 +79,8 @@ class ChaosGrad(torch.optim.Optimizer):
     _FRUST_THRESH = 0.75
     _FRUST_NOISE = 0.01
     _FRUST_META_RESET = 0.30
+    _FRUST_IMPROVE_TOL = 1e-4
+    _FRUST_SCALE_FLOOR = 1e-10

     # Cold-start helpers
     _GENESIS_SCALAR = 1e-6
@@ -311,6 +313,150 @@ def _row_mode(self, p: torch.Tensor) -> bool:
         """True when this parameter uses per-row meta-params."""
         return self._meta_resolution == 'row' and p.dim() >= 2

+    @staticmethod
+    def _finite_scalar(x, fallback: float) -> float:
+        """Convert value to a finite float; otherwise return fallback."""
+        try:
+            if torch.is_tensor(x):
+                x = float(x.float().mean().item())
+            else:
+                x = float(x)
+        except Exception:
+            return fallback
+        if not math.isfinite(x):
+            return fallback
+        return x
+
+    def _state_requires_reinit(self, state: dict, g_f: torch.Tensor) -> bool:
+        """Return True when state is missing/corrupt and must cold-start."""
+        required = (
+            'step', 'init_lr', 'grad_ema', 'momentum',
+            'per_param_lr', 'per_param_beta',
+            'per_param_decay', 'per_param_alpha', 'v2',
+        )
+        if any(k not in state for k in required):
+            return True
+
+        init_lr = self._finite_scalar(state.get('init_lr', 0.0), 0.0)
+        if init_lr <= 0.0:
+            return True
+
+        for key in ('grad_ema', 'momentum', 'v2'):
+            t = state.get(key)
+            if not torch.is_tensor(t) or t.shape != g_f.shape:
+                return True
+            if not bool(torch.isfinite(t).all().item()):
+                return True
+
+        return False
+
+    def _sanitize_meta_params(self, state: dict, p: torch.Tensor, group: dict) -> None:
+        """Keep adaptive meta-params finite, in-range, and shape-consistent."""
+        row_mode = self._row_mode(p)
+        rows = p.shape[0] if p.dim() >= 2 else 1
+
+        is_hebbian = group.get('is_hebbian', False)
+        beta_equil = self._finite_scalar(group.get('beta_equil', 0.90), 0.90)
+        init_decay = 0.0 if is_hebbian else self._finite_scalar(group.get('init_decay', 0.0), 0.0)
+
+        init_lr = self._finite_scalar(state.get('init_lr', self._LR_MIN), self._LR_MIN)
+        init_lr = max(self._LR_MIN, min(self._LR_MAX, init_lr))
+        state['init_lr'] = init_lr
+
+        def _sanitize_scalar(v, lo, hi, fallback):
+            x = self._finite_scalar(v, fallback)
+            if lo is not None:
+                x = max(lo, x)
+            if hi is not None:
+                x = min(hi, x)
+            return x
+
+        def _sanitize_row(v, lo, hi, fallback):
+            if torch.is_tensor(v):
+                t = v.to(device=p.device, dtype=torch.float32)
+                if t.shape != (rows,):
+                    if t.numel() == 1:
+                        t = t.reshape(1).expand(rows).clone()
+                    else:
+                        flat = t.reshape(-1)
+                        if flat.numel() >= rows:
+                            t = flat[:rows].clone()
+                        else:
+                            t = torch.full((rows,), fallback, dtype=torch.float32, device=p.device)
+                            t[:flat.numel()] = flat
+            else:
+                t = torch.full(
+                    (rows,),
+                    self._finite_scalar(v, fallback),
+                    dtype=torch.float32,
+                    device=p.device,
+                )
+
+            pos_fill = hi if hi is not None else fallback
+            neg_fill = lo if lo is not None else fallback
+            t = torch.nan_to_num(t, nan=fallback, posinf=pos_fill, neginf=neg_fill)
+            if lo is not None or hi is not None:
+                t.clamp_(
+                    lo if lo is not None else -float('inf'),
+                    hi if hi is not None else float('inf'),
+                )
+            return t
+
+        if row_mode:
+            state['per_param_lr'] = _sanitize_row(
+                state.get('per_param_lr', init_lr),
+                self._LR_MIN,
+                self._LR_MAX,
+                init_lr,
+            )
+            state['per_param_beta'] = _sanitize_row(
+                state.get('per_param_beta', beta_equil),
+                self._BETA_MIN,
+                self._BETA_MAX,
+                beta_equil,
+            )
+            if is_hebbian:
+                state['per_param_decay'] = torch.zeros((rows,), dtype=torch.float32, device=p.device)
+            else:
+                state['per_param_decay'] = _sanitize_row(
+                    state.get('per_param_decay', init_decay),
+                    0.0,
+                    self._DECAY_MAX,
+                    init_decay,
+                )
+            state['per_param_alpha'] = _sanitize_row(
+                state.get('per_param_alpha', 0.5),
+                0.0,
+                1.0,
+                0.5,
+            )
+            return
+
+        state['per_param_lr'] = _sanitize_scalar(
+            state.get('per_param_lr', init_lr),
+            self._LR_MIN,
+            self._LR_MAX,
+            init_lr,
+        )
+        state['per_param_beta'] = _sanitize_scalar(
+            state.get('per_param_beta', beta_equil),
+            self._BETA_MIN,
+            self._BETA_MAX,
+            beta_equil,
+        )
+        state['per_param_decay'] = 0.0 if is_hebbian else _sanitize_scalar(
+            state.get('per_param_decay', init_decay),
+            0.0,
+            self._DECAY_MAX,
+            init_decay,
+        )
+        state['per_param_alpha'] = _sanitize_scalar(
+            state.get('per_param_alpha', 0.5),
+            0.0,
+            1.0,
+            0.5,
+        )
+
     def _init_param_state(
         self,
         p: torch.Tensor,
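
Note: a minimal standalone sketch of the _finite_scalar contract introduced above, with illustrative inputs (the helper name and asserts here are for demonstration only, not part of the commit):

    import math
    import torch

    def finite_scalar(x, fallback):
        # Tensors collapse to the mean of their float view; anything
        # unconvertible or non-finite returns the fallback instead.
        try:
            x = float(x.float().mean().item()) if torch.is_tensor(x) else float(x)
        except Exception:
            return fallback
        return x if math.isfinite(x) else fallback

    assert finite_scalar(torch.tensor([1.0, 3.0]), 0.0) == 2.0  # tensor -> mean
    assert finite_scalar(float('nan'), 0.5) == 0.5              # non-finite -> fallback
    assert finite_scalar('oops', 0.1) == 0.1                    # unconvertible -> fallback
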
@@ -554,6 +700,11 @@ def step(self, closure=None):
         with torch.enable_grad():
             loss = closure()

+        if not math.isfinite(self._frustration):
+            self._frustration = 0.0
+        if not math.isfinite(self._best_loss):
+            self._best_loss = float('inf')
+
         burst_now = (self._frustration > self._FRUST_THRESH) or self._force_plateau_escape
         self._force_plateau_escape = False

@@ -577,33 +728,42 @@ def step(self, closure=None):
                 )

                 g_f = p.grad.float().detach()
+                if not bool(torch.isfinite(g_f).all().item()):
+                    # Do not let NaN/Inf gradients poison state/weights.
+                    self.reset_param_state(p)
+                    continue
+
+                # Respect hard architectural constraints in all optimizer
+                # pathways (signals, moments, preconditioner).
+                g_signal = g_f
+                if is_core and g_signal.dim() == 2 and g_signal.shape[0] == g_signal.shape[1]:
+                    g_signal = g_signal.clone()
+                    g_signal.fill_diagonal_(0.0)

                 # ---- Cold start ----
                 if not self.state[p]:
-                    self.state[p] = self._init_param_state(p, g_f, group)
+                    self.state[p] = self._init_param_state(p, g_signal, group)

                 state = self.state[p]

-                # Neurogenesis/legacy checkpoints can occasionally produce
-                # partial state payloads. Re-seed missing essentials so step
-                # logic remains total and shape-safe.
-                required = (
-                    'step', 'init_lr', 'grad_ema', 'momentum',
-                    'per_param_lr', 'per_param_beta',
-                    'per_param_decay', 'per_param_alpha', 'v2',
-                )
-                if any(k not in state for k in required):
-                    self.state[p] = self._init_param_state(p, g_f, group)
+                # Re-seed invalid state payloads so the update remains stable.
+                if self._state_requires_reinit(state, g_signal):
+                    self.state[p] = self._init_param_state(p, g_signal, group)
                     state = self.state[p]

+                # Sanitize before meta-update math so corrupt payloads cannot
+                # explode or shape-break the update equations.
+                self._sanitize_meta_params(state, p, group)
+
                 state['step'] += 1
                 step = state['step']

                 # ---- Hypergradient signals ----
-                sigs = self._compute_signals(g_f, state, p)
+                sigs = self._compute_signals(g_signal, state, p)

                 # ---- Meta-parameter update ----
                 self._update_meta_params(state, sigs, group)
+                self._sanitize_meta_params(state, p, group)

                 per_lr = state['per_param_lr']
                 per_beta = state['per_param_beta']
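
Note: the gradient guard and the chaos-core diagonal constraint added above, sketched in isolation (a square parameter is assumed; in the commit a non-finite gradient triggers reset_param_state and skips the update):

    import torch

    # Guard: one non-finite entry disqualifies the whole gradient.
    g_bad = torch.randn(4, 4)
    g_bad[1, 2] = float('inf')
    assert not torch.isfinite(g_bad).all()

    # Constraint: the diagonal of a square chaos-core matrix is
    # architecturally zero, so every derived signal must keep it zero.
    g = torch.randn(4, 4)
    g_signal = g.clone()
    g_signal.fill_diagonal_(0.0)
    assert g_signal.diagonal().abs().sum() == 0.0
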
@@ -614,7 +774,7 @@ def step(self, closure=None):

                 # In row mode per_lr / per_beta are (rows,) tensors. Reshape
                 # them for broadcasting across the full parameter shape.
-                row_mode = torch.is_tensor(per_lr)
+                row_mode = self._row_mode(p)
                 if row_mode:
                     view_shape = (-1,) + (1,) * (p.dim() - 1)
                     per_lr_b = per_lr.view(view_shape)
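
Note: the row-mode broadcast trick used above, as a self-contained sketch (shapes are illustrative):

    import torch

    p = torch.randn(4, 3, 2)           # parameter with 4 rows
    per_lr = torch.full((4,), 1e-3)    # one learning rate per row

    # (rows,) -> (rows, 1, 1) so the meta-param broadcasts over p.
    view_shape = (-1,) + (1,) * (p.dim() - 1)
    per_lr_b = per_lr.view(view_shape)
    assert (per_lr_b * p).shape == p.shape
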
@@ -637,19 +797,19 @@ def step(self, closure=None):

                 # ---- Update grad EMA (for next step's signals) ----
                 state['grad_ema'].mul_(self._SIGNAL_ALPHA).add_(
-                    g_f, alpha=1.0 - self._SIGNAL_ALPHA
+                    g_signal, alpha=1.0 - self._SIGNAL_ALPHA
                 )

                 # ---- Gradient centralization ----
-                g_proc = g_f
-                if g_f.dim() >= 2:
-                    dims = tuple(range(1, g_f.dim()))
-                    g_mean = g_f.mean(dim=dims, keepdim=True)
+                g_proc = g_signal
+                if g_signal.dim() >= 2:
+                    dims = tuple(range(1, g_signal.dim()))
+                    g_mean = g_signal.mean(dim=dims, keepdim=True)
                     if row_mode and torch.is_tensor(per_alpha):
                         if torch.any(per_alpha > 1e-3):
-                            g_proc = g_f - per_alpha_b * g_mean
+                            g_proc = g_signal - per_alpha_b * g_mean
                     elif per_alpha > 1e-3:
-                        g_proc = g_f - per_alpha * g_mean
+                        g_proc = g_signal - per_alpha * g_mean

                 # ---- Zero diagonal on chaos core gradient ----
                 if is_core and g_proc.dim() == 2 and g_proc.shape[0] == g_proc.shape[1]:
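
Note: gradient centralization, as wired above, subtracts the per-row mean from each row. A sketch of the fully centralized case (alpha = 1 for clarity; the commit blends by per_param_alpha instead):

    import torch

    g_signal = torch.randn(3, 5)
    dims = tuple(range(1, g_signal.dim()))          # all dims but the first
    g_mean = g_signal.mean(dim=dims, keepdim=True)  # one mean per row
    g_proc = g_signal - g_mean

    # Each row of the processed gradient is now (numerically) zero-mean.
    assert torch.allclose(g_proc.mean(dim=dims), torch.zeros(3), atol=1e-6)
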
@@ -664,6 +824,10 @@ def step(self, closure=None):
                 else:
                     v.mul_(per_beta).add_(g_proc, alpha=1.0 - per_beta)

+                # Keep optimizer state in the same constrained subspace as W.
+                if is_core and v.dim() == 2 and v.shape[0] == v.shape[1]:
+                    v.fill_diagonal_(0.0)
+
                 # ---- Frustration burst ----
                 # Noise is scaled to the current momentum RMS so that the
                 # perturbation is always a comparable fraction of "how fast we
@@ -686,6 +850,9 @@ def step(self, closure=None):
                             noise = noise - per_alpha * noise_mean
                     v.add_(noise)

+                    if is_core and v.dim() == 2 and v.shape[0] == v.shape[1]:
+                        v.fill_diagonal_(0.0)
+
                     if burst_type == 'full':
                         beta_equil = group.get('beta_equil', 0.90)
                         if row_mode:
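
Note: one plausible reading of the momentum-RMS noise scaling that the comments above describe; the exact scaling code sits outside this hunk, so the formula below is an assumption (_FRUST_NOISE value copied from the constants hunk):

    import torch

    FRUST_NOISE = 0.01

    v = torch.randn(4, 4) * 0.05                   # momentum buffer
    rms = v.pow(2).mean().sqrt().clamp_min(1e-12)  # current update speed

    # The kick stays a fixed fraction of how fast we are already moving.
    noise = torch.randn_like(v) * (FRUST_NOISE * rms)
    v.add_(noise)
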
@@ -715,7 +882,11 @@ def step(self, closure=None):

                 # ---- Elementwise second-moment (Adam-style) ----
                 v2 = state['v2']
-                v2.mul_(self._BETA2).addcmul_(g_f, g_f, value=1.0 - self._BETA2)
+                # Keep denominator dynamics consistent with the transformed
+                # gradient that drives momentum (centralization included).
+                v2.mul_(self._BETA2).addcmul_(g_proc, g_proc, value=1.0 - self._BETA2)
+                if is_core and v2.dim() == 2 and v2.shape[0] == v2.shape[1]:
+                    v2.fill_diagonal_(0.0)
                 bias_corr_v2 = max(1.0 - self._BETA2 ** step, self._EPS)
                 denom = (v2 / bias_corr_v2).sqrt().add_(self._EPS)

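
Note: the second-moment denominator follows the familiar Adam recipe, now fed with g_proc instead of the raw gradient. A condensed sketch with placeholder constants (the _BETA2 / _EPS values below are illustrative, not read from this diff):

    import torch

    BETA2, EPS = 0.999, 1e-8
    g_proc = torch.randn(4, 4)
    v2 = torch.zeros_like(g_proc)

    for step in range(1, 4):
        # EMA of squared (transformed) gradients, Adam-style.
        v2.mul_(BETA2).addcmul_(g_proc, g_proc, value=1.0 - BETA2)
        # Bias correction keeps early-step denominators from collapsing.
        bias_corr = max(1.0 - BETA2 ** step, EPS)
        denom = (v2 / bias_corr).sqrt().add_(EPS)
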
@@ -761,18 +932,48 @@ def report_loss(self, loss_value: float) -> None:

        Call this once per optimizer step (the trainer does this automatically
        when ChaosGrad is detected as the active optimizer).
+
+        The frustration signal is sign-agnostic and scale-robust:
+        - meaningful relative improvement -> 0.0 signal
+        - exact plateau / regression -> 1.0 signal
+        - sub-threshold improvement -> linearly reduced signal
        """
        loss = float(loss_value)
-        if loss < self._best_loss * 0.9999:
-            self._best_loss = loss
-            frustration_signal = 0.0
+
+        if not math.isfinite(self._frustration):
+            self._frustration = 0.0
+
+        if not math.isfinite(loss):
+            # Treat NaN/Inf losses as maximal stagnation signal without
+            # corrupting best_loss / frustration with NaN values.
+            self._frustration = (
+                self._frustration * self._FRUST_DECAY
+                + 1.0 * (1.0 - self._FRUST_DECAY)
+            )
+            self._frustration = max(0.0, min(1.0, self._frustration))
+            return
+
+        if not math.isfinite(self._best_loss):
+            self._best_loss = loss
+            frustration_signal = 0.0
        else:
-            frustration_signal = min(1.0, loss / max(self._best_loss, 1e-10))
+            scale = max(abs(self._best_loss), abs(loss), self._FRUST_SCALE_FLOOR)
+            rel_improvement = (self._best_loss - loss) / scale
+
+            if rel_improvement > self._FRUST_IMPROVE_TOL:
+                self._best_loss = loss
+                frustration_signal = 0.0
+            else:
+                # 1.0 for plateau/regression, smoothly reduced by small gains
+                # that do not pass the "new best" threshold.
+                sub_thr_gain = max(0.0, min(1.0, rel_improvement / self._FRUST_IMPROVE_TOL))
+                frustration_signal = 1.0 - sub_thr_gain

        self._frustration = (
            self._frustration * self._FRUST_DECAY
            + frustration_signal * (1.0 - self._FRUST_DECAY)
        )
+        self._frustration = max(0.0, min(1.0, self._frustration))

    def trigger_plateau_escape(self) -> None:
        """Manually trigger a frustration burst on the next :meth:`step` call."""
