Skip to content

Commit 18e9fa1

Browse files
Merge pull request #3490 from AI-Hypercomputer:skip_optimizer
PiperOrigin-RevId: 896614970
2 parents b6bdfc4 + 11b7ec4 commit 18e9fa1

5 files changed

Lines changed: 214 additions & 3 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -822,6 +822,15 @@ gradient_clipping_threshold: 1.0
822822
gradient_accumulation_steps: 1
823823

824824
opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon"
825+
826+
# If True, skip the training step when loss or gradient spike is detected
827+
# Neither the weights nor the optimizer momenta (if applicable) are updated
828+
skip_step_on_spikes: False
829+
# The rolling interval to calculate the mean and standard deviation
830+
skip_step_interval: 128
831+
# The scaling factor to determine if a spike occurred
832+
skip_step_scaling_factor: 6.0
833+
825834
# List of parameter names/patterns to train.
826835
# If non-empty, all other parameters will be frozen. Example: ['.*indexer.*'].
827836
# If empty (default), all parameters are trained.

src/maxtext/configs/types.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22

33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -1211,6 +1211,13 @@ class Optimizer(BaseModel):
12111211
"""Configuration for the optimizer and learning rate schedule."""
12121212

12131213
opt_type: OptimizerType = Field(OptimizerType.ADAMW, description="The type of optimizer to use.")
1214+
skip_step_on_spikes: bool = Field(
1215+
False, description="If True, skip the training step when loss or gradient spike is detected."
1216+
)
1217+
skip_step_interval: PositiveInt = Field(
1218+
128, description="The rolling interval to calculate the mean and standard deviation."
1219+
)
1220+
skip_step_scaling_factor: float = Field(6.0, description="The scaling factor to determine if a spike occurred.")
12141221
gradient_accumulation_steps: PositiveInt = Field(
12151222
1, description="Number of steps to accumulate gradients before updating."
12161223
)

src/maxtext/optimizers/optimizers.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import jax.numpy as jnp
2121

2222
import optax
23+
from maxtext.utils import max_logging
2324
from optax.contrib._muon import muon
2425
from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers
2526

@@ -48,6 +49,124 @@ def get_adamw_mask(config):
4849
return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)
4950

5051

52+
def _compute_rolling_stats(arr: jax.Array, count: jax.Array, interval: int):
53+
"""Computes mean and unbiased std (Bessel's correction) over a rolling window."""
54+
valid_elements = jnp.minimum(count, interval)
55+
safe_elements = jnp.maximum(1, valid_elements)
56+
mask = jnp.arange(interval) < valid_elements
57+
58+
mean = jnp.sum(jnp.where(mask, arr, 0.0)) / safe_elements
59+
sq_diff = jnp.where(mask, (arr - mean) ** 2, 0.0)
60+
61+
# Use Bessel's correction (N - 1) for unbiased variance to align with torch.std
62+
variance = jnp.sum(sq_diff) / jnp.maximum(1, valid_elements - 1)
63+
std = jnp.sqrt(variance)
64+
return mean, std
65+
66+
67+
def skip_step_on_spikes(
    inner_opt: optax.GradientTransformation, interval: int, scaling_factor: float
) -> optax.GradientTransformationExtraArgs:
  """Wrapper that skips updates when loss or grad_norm spike.

  This wrapper calculates a rolling mean and standard deviation (using
  Bessel's correction) over the last `interval` steps for both the loss
  and the gradient norm. If the current step's loss or gradient norm
  exceeds `mean + scaling_factor * std`, the update is zeroed and the
  optimizer state is not advanced, effectively skipping the step.

  Reference implementation:
  https://github.com/allenai/OLMo-core/blob/c757b7c3c15197154c753d883330afbfa4869dcc/src/olmo_core/optim/skip_step_optimizer.py#L12

  Args:
    inner_opt: The inner Optax gradient transformation to wrap.
    interval: The number of recent steps to use for calculating mean and std.
    scaling_factor: The multiplier for standard deviation to set the spike threshold.

  Returns:
    An optax.GradientTransformationExtraArgs that skips spikes.
  """

  def init_fn(params):
    # State is a plain dict pytree: the wrapped optimizer's state plus
    # fixed-size rolling buffers for loss/grad-norm history. `count` is the
    # total number of (loss-carrying) steps seen; `is_skipped` records
    # whether the most recent step was skipped (for external monitoring).
    return {
        "inner_state": inner_opt.init(params),
        "losses": jnp.zeros(interval, dtype=jnp.float32),
        "grad_norms": jnp.zeros(interval, dtype=jnp.float32),
        "count": jnp.zeros((), dtype=jnp.int32),
        "is_skipped": jnp.array(False, dtype=jnp.bool_),
    }

  def update_fn(updates, state, params=None, **extra_args):
    # Using `pop()` removes `loss` and `grad_norm` from `extra_args` before they are
    # passed downstream to `inner_opt.update()`. This prevents `TypeError` if the
    # inner optimizer doesn't explicitly accept these as `kwargs`.
    loss = extra_args.pop("loss", None)
    grad_norm = extra_args.pop("grad_norm", None)

    # Fallback to standard update if loss is not provided
    # (rolling buffers and `count` are left untouched in that case).
    if loss is None:
      inner_updates, new_inner_state = inner_opt.update(updates, state["inner_state"], params, **extra_args)
      return inner_updates, {
          "inner_state": new_inner_state,
          "losses": state["losses"],
          "grad_norms": state["grad_norms"],
          "count": state["count"],
          "is_skipped": jnp.array(False, dtype=jnp.bool_),
      }

    count = state["count"]
    losses = state["losses"]
    grad_norms = state["grad_norms"]

    # Compute rolling stats over the populated portion of each buffer.
    loss_mean, loss_std = _compute_rolling_stats(losses, count, interval)
    grad_norm_mean, grad_norm_std = _compute_rolling_stats(grad_norms, count, interval)

    # Check if the current metrics are within the allowed thresholds.
    # Grad-norm gating only applies when the caller supplied `grad_norm`
    # (a Python-level check, resolved at trace time).
    is_loss_ok = (loss - loss_mean) <= scaling_factor * loss_std
    if grad_norm is not None:
      is_grad_norm_ok = (grad_norm - grad_norm_mean) <= scaling_factor * grad_norm_std
      is_ok = jnp.logical_and(is_loss_ok, is_grad_norm_ok)
    else:
      is_ok = is_loss_ok

    # Only enforce skip if we have at least half the interval filled (or 2 elements minimum);
    # with too little history the std estimate is meaningless, so always step.
    min_history = max(2, interval // 2)
    is_warmup = (count + 1) < min_history
    is_ok = jnp.logical_or(is_warmup, is_ok)

    # Conditionally execute the inner optimizer to prevent momentum poisoning.
    # Both branches return the same (updates, inner_state) pytree structure,
    # as required by jax.lax.cond.
    def do_update():
      return inner_opt.update(updates, state["inner_state"], params, **extra_args)

    def skip_update():
      # use callback to work with jax.jit and jax.lax.cond for logging
      jax.debug.callback(lambda c: max_logging.warning(f"Step {c}: Optimizer step skipped due to spike."), count)
      inner_updates = jax.tree_util.tree_map(jnp.zeros_like, updates)
      # Return the *old* inner state so momentum/variance are not advanced.
      return inner_updates, state["inner_state"]

    inner_updates, new_inner_state = jax.lax.cond(is_ok, do_update, skip_update)

    # Update rolling buffers (append even if skipped so spikes can become the new baseline)
    idx = count % interval
    new_losses = losses.at[idx].set(loss)

    new_grad_norms = grad_norms
    if grad_norm is not None:
      new_grad_norms = grad_norms.at[idx].set(grad_norm)

    new_state = {
        "inner_state": new_inner_state,
        "losses": new_losses,
        "grad_norms": new_grad_norms,
        "count": count + 1,
        "is_skipped": jnp.logical_not(is_ok),
    }
    return inner_updates, new_state

  return optax.GradientTransformationExtraArgs(init_fn, update_fn)
168+
169+
51170
def get_optimizer(config, learning_rate_schedule, model=None):
52171
"""Create optimizer."""
53172
if config.opt_type == "adamw":
@@ -100,6 +219,13 @@ def get_optimizer(config, learning_rate_schedule, model=None):
100219
else:
101220
raise ValueError(f"{config.opt_type=} is not a supported.")
102221

222+
if getattr(config, "skip_step_on_spikes", False):
223+
base_opt = skip_step_on_spikes(
224+
base_opt,
225+
interval=config.skip_step_interval,
226+
scaling_factor=config.skip_step_scaling_factor,
227+
)
228+
103229
# If a whitelist of trainable parameters is provided, freeze everything else.
104230
# When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
105231
trainable_patterns = getattr(config, "trainable_parameters_mask", None)

src/maxtext/trainers/pre_train/train.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from absl import app
2828

2929
import numpy as np
30+
import optax
3031

3132
import pathwaysutils # pylint: disable=unused-import
3233

@@ -399,7 +400,20 @@ def move(path, value):
399400
jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params),
400401
)
401402
)
402-
new_state = state.apply_gradients(grads=grads)
403+
404+
if getattr(config, "skip_step_on_spikes", False):
405+
grad_norm = max_utils.l2norm_pytree(grads)
406+
# TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually.
407+
updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm)
408+
new_params = optax.apply_updates(state.params, updates)
409+
410+
new_state = state.replace(
411+
step=state.step + 1,
412+
params=new_params,
413+
opt_state=new_opt_state,
414+
)
415+
else:
416+
new_state = state.apply_gradients(grads=grads)
403417

404418
# Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
405419
if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:

tests/unit/optimizers_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import unittest
1818
from unittest.mock import patch
1919
import jax
20+
import optax
21+
import jax.numpy as jnp
2022

2123
import pytest
2224
from absl.testing import parameterized
@@ -428,5 +430,58 @@ def learning_rate_schedule(step):
428430
self.assertFalse(jax.numpy.all(updates["layer1"]["kernel"] == 0))
429431

430432

433+
class SkipStepOnSpikesTest(parameterized.TestCase):
  """Tests for the skip_step_on_spikes optimizer wrapper."""

  def _run_spike_test(self, spike_kwargs):
    wrapped = optimizers.skip_step_on_spikes(optax.sgd(0.1), interval=4, scaling_factor=1.0)

    params = {"x": jnp.array([1.0])}
    grads = {"x": jnp.array([1.0])}
    opt_state = wrapped.init(params)

    # Feed two "normal" steps so the rolling stats have enough history
    # (skipping is only enforced once count >= interval / 2, min 2).
    normal_kwargs = {name: jnp.array(1.0) for name in spike_kwargs}
    for _ in range(2):
      updates, opt_state = wrapped.update(grads, opt_state, params, **normal_kwargs)
      self.assertFalse(jnp.all(updates["x"] == 0.0))
      self.assertFalse(opt_state["is_skipped"])

    # Third step carries the spike values: updates must be zeroed and flagged.
    spiked_kwargs = {name: jnp.array(value) for name, value in spike_kwargs.items()}
    updates, opt_state = wrapped.update(grads, opt_state, params, **spiked_kwargs)
    self.assertTrue(jnp.all(updates["x"] == 0.0))
    self.assertTrue(opt_state["is_skipped"])

  def test_skip_step_on_loss_spike(self):
    self._run_spike_test({"loss": 100.0})

  def test_skip_step_on_grad_norm_spike(self):
    self._run_spike_test({"loss": 1.0, "grad_norm": 100.0})

  def test_skip_step_on_both_spike(self):
    self._run_spike_test({"loss": 100.0, "grad_norm": 100.0})

  def test_no_skip_without_kwargs(self):
    wrapped = optimizers.skip_step_on_spikes(optax.sgd(0.1), interval=4, scaling_factor=1.0)

    params = {"x": jnp.array([1.0])}
    opt_state = wrapped.init(params)

    # Without loss/grad_norm kwargs the wrapper falls back to a plain update.
    updates, opt_state = wrapped.update({"x": jnp.array([1.0])}, opt_state, params)
    self.assertFalse(jnp.all(updates["x"] == 0.0))
    self.assertFalse(opt_state["is_skipped"])
    # The rolling-history counter must not advance on the fallback path.
    self.assertEqual(opt_state["count"], 0)
431486
if __name__ == "__main__":
432487
unittest.main()

0 commit comments

Comments
 (0)