
Commit 5b8835d

Support option to skip the optimizer for training step
1 parent dc29039 commit 5b8835d

5 files changed: 204 additions & 3 deletions


src/maxtext/configs/base.yml

Lines changed: 10 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -790,6 +790,15 @@ gradient_clipping_threshold: 1.0
 gradient_accumulation_steps: 1

 opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon"
+
+# If True, skip the training step when a loss or gradient-norm spike is detected.
+# Neither the weights nor the optimizer momentum (if any) are updated on a skipped step.
+skip_step_on_spikes: False
+# The rolling window length used to compute the mean and standard deviation.
+skip_step_interval: 128
+# The scaling factor used to decide whether a spike occurred.
+skip_step_scaling_factor: 6.0
+
 # List of parameter names/patterns to train.
 # If non-empty, all other parameters will be frozen. Example: ['.*indexer.*'].
 # If empty (default), all parameters are trained.
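
To make the defaults concrete: with skip_step_interval=128 and skip_step_scaling_factor=6.0, a step is skipped when the current loss (or gradient norm) exceeds the rolling mean by more than six rolling standard deviations. The snippet below is a small, hypothetical illustration of that criterion; it is not part of the commit and the numbers are made up.

# Hypothetical illustration of the spike criterion controlled by the settings above
# (interval=128, scaling_factor=6.0). Numbers are made up.
import numpy as np

recent_losses = np.full(128, 2.0)                   # rolling window of the last `skip_step_interval` losses
recent_losses[:16] += np.linspace(-0.05, 0.05, 16)  # a little noise so the std is non-zero

mean = recent_losses.mean()
std = recent_losses.std(ddof=1)                     # ddof=1 matches the unbiased (Bessel-corrected) std in the wrapper
threshold = mean + 6.0 * std                        # skip_step_scaling_factor = 6.0

current_loss = 9.0                                  # far above the rolling statistics
print(current_loss > threshold)                     # True -> this step would be skipped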

src/maxtext/configs/types.py

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC

 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1150,6 +1150,13 @@ class Optimizer(BaseModel):
   """Configuration for the optimizer and learning rate schedule."""

   opt_type: OptimizerType = Field(OptimizerType.ADAMW, description="The type of optimizer to use.")
+  skip_step_on_spikes: bool = Field(
+      False, description="If True, skip the training step when a loss or gradient spike is detected."
+  )
+  skip_step_interval: PositiveInt = Field(
+      128, description="The rolling window length used to compute the mean and standard deviation."
+  )
+  skip_step_scaling_factor: float = Field(6.0, description="The scaling factor used to decide whether a spike occurred.")
   gradient_accumulation_steps: PositiveInt = Field(
       1, description="Number of steps to accumulate gradients before updating."
   )
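
As a quick sanity check on how these fields validate, the standalone sketch below mirrors the new fields in a throwaway Pydantic model. OptimizerSketch is a hypothetical stand-in for illustration only, not the real Optimizer class in types.py.

# Hypothetical, self-contained mirror of the new fields; illustration only.
from pydantic import BaseModel, Field, PositiveInt, ValidationError

class OptimizerSketch(BaseModel):
  skip_step_on_spikes: bool = Field(False, description="Skip the step when a loss or gradient spike is detected.")
  skip_step_interval: PositiveInt = Field(128, description="Rolling window length for the mean/std.")
  skip_step_scaling_factor: float = Field(6.0, description="Std multiplier that sets the spike threshold.")

cfg = OptimizerSketch(skip_step_on_spikes=True, skip_step_interval=64)
print(cfg.skip_step_on_spikes, cfg.skip_step_interval, cfg.skip_step_scaling_factor)  # True 64 6.0

try:
  OptimizerSketch(skip_step_interval=0)  # PositiveInt rejects non-positive values
except ValidationError as err:
  print("rejected:", len(err.errors()), "error(s)")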

src/maxtext/optimizers/optimizers.py

Lines changed: 120 additions & 0 deletions
@@ -48,6 +48,119 @@ def get_adamw_mask(config):
   return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)


+def _compute_rolling_stats(arr: jax.Array, count: jax.Array, interval: int):
+  """Computes the mean and unbiased std (Bessel's correction) over a rolling window."""
+  valid_elements = jnp.minimum(count, interval)
+  safe_elements = jnp.maximum(1, valid_elements)
+  mask = jnp.arange(interval) < valid_elements
+
+  mean = jnp.sum(jnp.where(mask, arr, 0.0)) / safe_elements
+  sq_diff = jnp.where(mask, (arr - mean) ** 2, 0.0)
+
+  # Use Bessel's correction (N - 1) for unbiased variance to align with torch.std.
+  variance = jnp.sum(sq_diff) / jnp.maximum(1, valid_elements - 1)
+  std = jnp.sqrt(variance)
+  return mean, std
+
+
+def skip_step_on_spikes(
+    inner_opt: optax.GradientTransformation, interval: int, scaling_factor: float
+) -> optax.GradientTransformationExtraArgs:
+  """Wrapper that skips updates when the loss or gradient norm spikes.
+
+  This wrapper calculates a rolling mean and standard deviation (using
+  Bessel's correction) over the last `interval` steps for both the loss
+  and the gradient norm. If the current step's loss or gradient norm
+  exceeds `mean + scaling_factor * std`, the update is zeroed and the
+  optimizer state is not advanced, effectively skipping the step.
+
+  Reference implementation:
+  https://github.com/allenai/OLMo-core/blob/c757b7c3c15197154c753d883330afbfa4869dcc/src/olmo_core/optim/skip_step_optimizer.py#L12
+
+  Args:
+    inner_opt: The inner Optax gradient transformation to wrap.
+    interval: The number of recent steps used to calculate the mean and std.
+    scaling_factor: The multiplier on the standard deviation that sets the spike threshold.
+
+  Returns:
+    An optax.GradientTransformationExtraArgs that skips spiky steps.
+  """
+
+  def init_fn(params):
+    return {
+        "inner_state": inner_opt.init(params),
+        "losses": jnp.zeros(interval, dtype=jnp.float32),
+        "grad_norms": jnp.zeros(interval, dtype=jnp.float32),
+        "count": jnp.zeros((), dtype=jnp.int32),
+    }
+
+  def update_fn(updates, state, params=None, **extra_args):
+    # Using `pop()` removes `loss` and `grad_norm` from `extra_args` before they are
+    # passed downstream to `inner_opt.update()`. This prevents a `TypeError` if the
+    # inner optimizer doesn't explicitly accept these as kwargs.
+    loss = extra_args.pop("loss", None)
+    grad_norm = extra_args.pop("grad_norm", None)
+
+    # Fall back to a standard update if the loss is not provided.
+    if loss is None:
+      inner_updates, new_inner_state = inner_opt.update(updates, state["inner_state"], params, **extra_args)
+      return inner_updates, {
+          "inner_state": new_inner_state,
+          "losses": state["losses"],
+          "grad_norms": state["grad_norms"],
+          "count": state["count"],
+      }
+
+    count = state["count"]
+    losses = state["losses"]
+    grad_norms = state["grad_norms"]
+
+    # Compute rolling stats.
+    loss_mean, loss_std = _compute_rolling_stats(losses, count, interval)
+    grad_norm_mean, grad_norm_std = _compute_rolling_stats(grad_norms, count, interval)
+
+    # Check whether the current metrics are within the allowed thresholds.
+    is_loss_ok = (loss - loss_mean) <= scaling_factor * loss_std
+    if grad_norm is not None:
+      is_grad_norm_ok = (grad_norm - grad_norm_mean) <= scaling_factor * grad_norm_std
+      is_ok = jnp.logical_and(is_loss_ok, is_grad_norm_ok)
+    else:
+      is_ok = is_loss_ok
+
+    # Only enforce the skip once at least half the interval is filled (2 elements minimum).
+    min_history = max(2, interval // 2)
+    is_warmup = (count + 1) < min_history
+    is_ok = jnp.logical_or(is_warmup, is_ok)
+
+    # Conditionally execute the inner optimizer to prevent momentum poisoning.
+    def do_update():
+      return inner_opt.update(updates, state["inner_state"], params, **extra_args)
+
+    def skip_update():
+      inner_updates = jax.tree_util.tree_map(jnp.zeros_like, updates)
+      return inner_updates, state["inner_state"]
+
+    inner_updates, new_inner_state = jax.lax.cond(is_ok, do_update, skip_update)
+
+    # Update the rolling buffers (append even if skipped so spikes can become the new baseline).
+    idx = count % interval
+    new_losses = losses.at[idx].set(loss)
+
+    new_grad_norms = grad_norms
+    if grad_norm is not None:
+      new_grad_norms = grad_norms.at[idx].set(grad_norm)
+
+    new_state = {
+        "inner_state": new_inner_state,
+        "losses": new_losses,
+        "grad_norms": new_grad_norms,
+        "count": count + 1,
+    }
+    return inner_updates, new_state
+
+  return optax.GradientTransformationExtraArgs(init_fn, update_fn)
+
+
 def get_optimizer(config, learning_rate_schedule, model=None):
   """Create optimizer."""
   if config.opt_type == "adamw":
@@ -100,6 +213,13 @@ def get_optimizer(config, learning_rate_schedule, model=None):
   else:
     raise ValueError(f"{config.opt_type=} is not supported.")

+  if getattr(config, "skip_step_on_spikes", False):
+    base_opt = skip_step_on_spikes(
+        base_opt,
+        interval=config.skip_step_interval,
+        scaling_factor=config.skip_step_scaling_factor,
+    )
+
   # If a whitelist of trainable parameters is provided, freeze everything else.
   # When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
   trainable_patterns = getattr(config, "trainable_parameters_mask", None)
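
For orientation, here is a minimal, self-contained sketch of how the wrapper is driven. The import path and the toy values are assumptions for illustration; in MaxText the wrapper is applied inside get_optimizer and fed loss/grad_norm from the training loop.

# Minimal usage sketch (assumed import path; toy values). Not part of the commit.
import jax.numpy as jnp
import optax

from maxtext.optimizers.optimizers import skip_step_on_spikes  # path assumed from the src/maxtext layout

params = {"w": jnp.ones((4,))}
opt = skip_step_on_spikes(optax.adamw(1e-3), interval=8, scaling_factor=3.0)
opt_state = opt.init(params)

# Four calm steps, then an obvious loss spike on the fifth.
for loss in [1.0, 1.0, 1.0, 1.0, 50.0]:
  grads = {"w": jnp.full((4,), 0.1)}
  grad_norm = optax.global_norm(grads)  # stand-in for max_utils.l2norm_pytree used in train.py
  updates, opt_state = opt.update(grads, opt_state, params, loss=jnp.asarray(loss), grad_norm=grad_norm)
  params = optax.apply_updates(params, updates)

# On the spiky step the returned updates are all zeros and the inner AdamW state is untouched,
# but the spiky loss is still written into the rolling buffer, so a persistent shift becomes the new baseline.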

src/maxtext/trainers/pre_train/train.py

Lines changed: 15 additions & 1 deletion
@@ -27,6 +27,7 @@
 from absl import app

 import numpy as np
+import optax

 import pathwaysutils  # pylint: disable=unused-import

@@ -391,7 +392,20 @@ def move(path, value):
           jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params),
       )
   )
-  new_state = state.apply_gradients(grads=grads)
+
+  if getattr(config, "skip_step_on_spikes", False):
+    grad_norm = max_utils.l2norm_pytree(grads)
+    # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually.
+    updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm)
+    new_params = optax.apply_updates(state.params, updates)
+
+    new_state = state.replace(
+        step=state.step + 1,
+        params=new_params,
+        opt_state=new_opt_state,
+    )
+  else:
+    new_state = state.apply_gradients(grads=grads)

   # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
   if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
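
The manual unpacking above is needed because flax's TrainState.apply_gradients calls tx.update(grads, opt_state, params) without forwarding extra keyword arguments, so loss and grad_norm would never reach the wrapper. Below is a minimal sketch of the same pattern with a toy TrainState; the wrapper import path is assumed and the values are illustrative.

# Minimal sketch of the manual-update branch (assumed import path; toy values). Not part of the commit.
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

from maxtext.optimizers.optimizers import skip_step_on_spikes  # path assumed

state = TrainState.create(
    apply_fn=None,  # not used in this sketch
    params={"w": jnp.ones((3,))},
    tx=skip_step_on_spikes(optax.sgd(0.1), interval=8, scaling_factor=3.0),
)

grads = {"w": jnp.full((3,), 0.2)}
loss = jnp.asarray(0.7)
grad_norm = optax.global_norm(grads)  # the trainer uses max_utils.l2norm_pytree here

# apply_gradients would drop loss/grad_norm, so call tx.update directly and rebuild the state.
updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm)
state = state.replace(
    step=state.step + 1,
    params=optax.apply_updates(state.params, updates),
    opt_state=new_opt_state,
)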

tests/unit/optimizers_test.py

Lines changed: 51 additions & 0 deletions
@@ -17,6 +17,8 @@
 import unittest
 from unittest.mock import patch
 import jax
+import optax
+import jax.numpy as jnp

 import pytest
 from absl.testing import parameterized
@@ -428,5 +430,54 @@ def learning_rate_schedule(step):
     self.assertFalse(jax.numpy.all(updates["layer1"]["kernel"] == 0))


+class SkipStepOnSpikesTest(parameterized.TestCase):
+  """Tests for the skip_step_on_spikes optimizer wrapper."""
+
+  def _run_spike_test(self, spike_kwargs):
+    inner_opt = optax.sgd(0.1)
+    opt = optimizers.skip_step_on_spikes(inner_opt, interval=4, scaling_factor=1.0)
+
+    params = {"x": jnp.array([1.0])}
+    opt_state = opt.init(params)
+
+    # Baseline kwargs used to fill the rolling window during warmup.
+    base_kwargs = {k: jnp.array(1.0) for k in spike_kwargs.keys()}
+
+    # Step 0: count = 0, still in warmup (skips are only enforced once count + 1 >= max(2, interval // 2)), so no skip.
+    updates, opt_state = opt.update({"x": jnp.array([1.0])}, opt_state, params, **base_kwargs)
+    self.assertFalse(jnp.all(updates["x"] == 0.0))
+
+    # Step 1: count = 1, mean = 1.0 and std = 0.0, and the metrics equal the mean, so this is not a spike; no skip.
+    updates, opt_state = opt.update({"x": jnp.array([1.0])}, opt_state, params, **base_kwargs)
+    self.assertFalse(jnp.all(updates["x"] == 0.0))
+
+    # Step 2: count = 2. Spike!
+    spike_kwargs_jnp = {k: jnp.array(v) for k, v in spike_kwargs.items()}
+    updates, opt_state = opt.update({"x": jnp.array([1.0])}, opt_state, params, **spike_kwargs_jnp)
+    self.assertTrue(jnp.all(updates["x"] == 0.0))
+
+  def test_skip_step_on_loss_spike(self):
+    self._run_spike_test({"loss": 100.0})
+
+  def test_skip_step_on_grad_norm_spike(self):
+    self._run_spike_test({"loss": 1.0, "grad_norm": 100.0})
+
+  def test_skip_step_on_both_spike(self):
+    self._run_spike_test({"loss": 100.0, "grad_norm": 100.0})
+
+  def test_no_skip_without_kwargs(self):
+    inner_opt = optax.sgd(0.1)
+    opt = optimizers.skip_step_on_spikes(inner_opt, interval=4, scaling_factor=1.0)
+
+    params = {"x": jnp.array([1.0])}
+    opt_state = opt.init(params)
+
+    # Without the loss/grad_norm kwargs, the wrapper behaves like the inner optimizer.
+    updates, opt_state = opt.update({"x": jnp.array([1.0])}, opt_state, params)
+    self.assertFalse(jnp.all(updates["x"] == 0.0))
+    # The count should not have been incremented.
+    self.assertEqual(opt_state["count"], 0)
+
+
 if __name__ == "__main__":
   unittest.main()
