@@ -48,6 +48,119 @@ def get_adamw_mask(config):
     return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)
 
 
+def _compute_rolling_stats(arr: jax.Array, count: jax.Array, interval: int):
+    """Computes mean and unbiased std (Bessel's correction) over a rolling window."""
+    valid_elements = jnp.minimum(count, interval)
+    safe_elements = jnp.maximum(1, valid_elements)
+    mask = jnp.arange(interval) < valid_elements
+
+    mean = jnp.sum(jnp.where(mask, arr, 0.0)) / safe_elements
+    sq_diff = jnp.where(mask, (arr - mean) ** 2, 0.0)
+
+    # Use Bessel's correction (N - 1) for unbiased variance to align with torch.std
+    variance = jnp.sum(sq_diff) / jnp.maximum(1, valid_elements - 1)
+    std = jnp.sqrt(variance)
+    return mean, std
+
+
+def skip_step_on_spikes(
+    inner_opt: optax.GradientTransformation, interval: int, scaling_factor: float
+) -> optax.GradientTransformationExtraArgs:
+    """Wrapper that skips updates when the loss or gradient norm spikes.
+
+    This wrapper calculates a rolling mean and standard deviation (using
+    Bessel's correction) over the last `interval` steps for both the loss
+    and the gradient norm. If the current step's loss or gradient norm
+    exceeds `mean + scaling_factor * std`, the update is zeroed and the
+    optimizer state is not advanced, effectively skipping the step.
+
+    Reference implementation:
+    https://github.com/allenai/OLMo-core/blob/c757b7c3c15197154c753d883330afbfa4869dcc/src/olmo_core/optim/skip_step_optimizer.py#L12
+
+    Args:
+        inner_opt: The inner Optax gradient transformation to wrap.
+        interval: The number of recent steps used to compute the mean and std.
+        scaling_factor: The multiplier on the standard deviation that sets the spike threshold.
+
+    Returns:
+        An optax.GradientTransformationExtraArgs that skips updates on spikes.
+    """
+
+    def init_fn(params):
+        return {
+            "inner_state": inner_opt.init(params),
+            "losses": jnp.zeros(interval, dtype=jnp.float32),
+            "grad_norms": jnp.zeros(interval, dtype=jnp.float32),
+            "count": jnp.zeros((), dtype=jnp.int32),
+        }
+
+    def update_fn(updates, state, params=None, **extra_args):
+        # `pop()` removes `loss` and `grad_norm` from `extra_args` before they are
+        # passed downstream to `inner_opt.update()`, preventing a `TypeError` if the
+        # inner optimizer doesn't explicitly accept them as kwargs.
+        loss = extra_args.pop("loss", None)
+        grad_norm = extra_args.pop("grad_norm", None)
+
+        # Fall back to a standard update if no loss is provided.
+        if loss is None:
+            inner_updates, new_inner_state = inner_opt.update(updates, state["inner_state"], params, **extra_args)
+            return inner_updates, {
+                "inner_state": new_inner_state,
+                "losses": state["losses"],
+                "grad_norms": state["grad_norms"],
+                "count": state["count"],
+            }
+
+        count = state["count"]
+        losses = state["losses"]
+        grad_norms = state["grad_norms"]
+
+        # Compute rolling stats over the recent history.
+        loss_mean, loss_std = _compute_rolling_stats(losses, count, interval)
+        grad_norm_mean, grad_norm_std = _compute_rolling_stats(grad_norms, count, interval)
+
+        # Check whether the current metrics are within the allowed thresholds.
+        is_loss_ok = (loss - loss_mean) <= scaling_factor * loss_std
+        if grad_norm is not None:
+            is_grad_norm_ok = (grad_norm - grad_norm_mean) <= scaling_factor * grad_norm_std
+            is_ok = jnp.logical_and(is_loss_ok, is_grad_norm_ok)
+        else:
+            is_ok = is_loss_ok
+
+        # Only enforce skipping once at least half the interval (minimum 2 steps) of history exists.
+        min_history = max(2, interval // 2)
+        is_warmup = (count + 1) < min_history
+        is_ok = jnp.logical_or(is_warmup, is_ok)
+
+        # Conditionally execute the inner optimizer so skipped steps don't poison momentum.
+        def do_update():
+            return inner_opt.update(updates, state["inner_state"], params, **extra_args)
+
+        def skip_update():
+            inner_updates = jax.tree_util.tree_map(jnp.zeros_like, updates)
+            return inner_updates, state["inner_state"]
+
+        inner_updates, new_inner_state = jax.lax.cond(is_ok, do_update, skip_update)
+
+        # Update the rolling buffers (appended even on a skip, so spikes can become the new baseline).
+        idx = count % interval
+        new_losses = losses.at[idx].set(loss)
+
+        new_grad_norms = grad_norms
+        if grad_norm is not None:
+            new_grad_norms = grad_norms.at[idx].set(grad_norm)
+
+        new_state = {
+            "inner_state": new_inner_state,
+            "losses": new_losses,
+            "grad_norms": new_grad_norms,
+            "count": count + 1,
+        }
+        return inner_updates, new_state
+
+    return optax.GradientTransformationExtraArgs(init_fn, update_fn)
+
+
 def get_optimizer(config, learning_rate_schedule, model=None):
     """Create optimizer."""
     if config.opt_type == "adamw":
@@ -100,6 +213,13 @@ def get_optimizer(config, learning_rate_schedule, model=None):
     else:
         raise ValueError(f"{config.opt_type=} is not a supported.")
 
+    if getattr(config, "skip_step_on_spikes", False):
+        base_opt = skip_step_on_spikes(
+            base_opt,
+            interval=config.skip_step_interval,
+            scaling_factor=config.skip_step_scaling_factor,
+        )
+
     # If a whitelist of trainable parameters is provided, freeze everything else.
     # When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
     trainable_patterns = getattr(config, "trainable_parameters_mask", None)
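
Not part of the diff above: a minimal usage sketch of the new wrapper, assuming skip_step_on_spikes is importable from this module. The loss function, parameters, batch, and the interval/scaling values are made-up stand-ins for illustration, not values taken from this repo.

import jax
import jax.numpy as jnp
import optax

# Hypothetical toy loss; any scalar-valued loss_fn is threaded through the same way.
def loss_fn(params, batch):
    return jnp.mean((params["w"] * batch["x"] - batch["y"]) ** 2)

params = {"w": jnp.ones((4,))}
opt = skip_step_on_spikes(optax.adamw(1e-3), interval=128, scaling_factor=3.0)
opt_state = opt.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # loss / grad_norm arrive in **extra_args of update_fn; the wrapper pops them
    # before delegating to the inner AdamW, which does not accept them.
    updates, opt_state = opt.update(
        grads, opt_state, params, loss=loss, grad_norm=optax.global_norm(grads)
    )
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

batch = {"x": jnp.ones((4,)), "y": jnp.zeros((4,))}
params, opt_state, loss = train_step(params, opt_state, batch)

Passing the loss and gradient norm through update() as extra args keeps the wrapper a pure Optax transformation: the skip decision runs inside jax.lax.cond, so the whole step remains jit-compatible.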
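The wiring added to get_optimizer reads three config fields. A hypothetical stand-in for the real training config (which is not shown in the diff) might look like:

from types import SimpleNamespace

# Hypothetical config fragment; only the spike-skipping fields are shown.
config_fragment = SimpleNamespace(
    skip_step_on_spikes=True,       # wrap base_opt with skip_step_on_spikes
    skip_step_interval=128,         # rolling-window length for the mean/std
    skip_step_scaling_factor=3.0,   # skip when metric > mean + 3.0 * std
)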