|
 import jax.numpy as jnp

 import optax
+from maxtext.utils import max_logging
 from optax.contrib._muon import muon
 from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers

@@ -48,6 +49,124 @@ def get_adamw_mask(config):
   return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)


+def _compute_rolling_stats(arr: jax.Array, count: jax.Array, interval: int):
+  """Computes mean and unbiased std (Bessel's correction) over a rolling window."""
+  valid_elements = jnp.minimum(count, interval)
+  safe_elements = jnp.maximum(1, valid_elements)
+  mask = jnp.arange(interval) < valid_elements
+
+  mean = jnp.sum(jnp.where(mask, arr, 0.0)) / safe_elements
+  sq_diff = jnp.where(mask, (arr - mean) ** 2, 0.0)
+
+  # Use Bessel's correction (N - 1) for unbiased variance to align with torch.std
+  variance = jnp.sum(sq_diff) / jnp.maximum(1, valid_elements - 1)
+  std = jnp.sqrt(variance)
+  return mean, std
+
+
+def skip_step_on_spikes(
+    inner_opt: optax.GradientTransformation, interval: int, scaling_factor: float
+) -> optax.GradientTransformationExtraArgs:
+  """Wrapper that skips updates when the loss or gradient norm spikes.
+
+  This wrapper calculates a rolling mean and standard deviation (using
+  Bessel's correction) over the last `interval` steps for both the loss
+  and the gradient norm. If the current step's loss or gradient norm
+  exceeds `mean + scaling_factor * std`, the update is zeroed and the
+  optimizer state is not advanced, effectively skipping the step.
+
+  Reference implementation:
+  https://github.com/allenai/OLMo-core/blob/c757b7c3c15197154c753d883330afbfa4869dcc/src/olmo_core/optim/skip_step_optimizer.py#L12
+
+  Args:
+    inner_opt: The inner Optax gradient transformation to wrap.
+    interval: The number of recent steps to use for calculating mean and std.
+    scaling_factor: The multiplier for standard deviation to set the spike threshold.
+
+  Returns:
+    An optax.GradientTransformationExtraArgs that skips spikes.
+  """
+
+  def init_fn(params):
+    return {
+        "inner_state": inner_opt.init(params),
+        "losses": jnp.zeros(interval, dtype=jnp.float32),
+        "grad_norms": jnp.zeros(interval, dtype=jnp.float32),
+        "count": jnp.zeros((), dtype=jnp.int32),
+        "is_skipped": jnp.array(False, dtype=jnp.bool_),
+    }
+
+  def update_fn(updates, state, params=None, **extra_args):
+    # Using `pop()` removes `loss` and `grad_norm` from `extra_args` before they are
+    # passed downstream to `inner_opt.update()`. This prevents a `TypeError` if the
+    # inner optimizer doesn't explicitly accept these as kwargs.
+    loss = extra_args.pop("loss", None)
+    grad_norm = extra_args.pop("grad_norm", None)
+
+    # Fall back to a standard update if no loss is provided.
+    if loss is None:
+      inner_updates, new_inner_state = inner_opt.update(updates, state["inner_state"], params, **extra_args)
+      return inner_updates, {
+          "inner_state": new_inner_state,
+          "losses": state["losses"],
+          "grad_norms": state["grad_norms"],
+          "count": state["count"],
+          "is_skipped": jnp.array(False, dtype=jnp.bool_),
+      }
+
+    count = state["count"]
+    losses = state["losses"]
+    grad_norms = state["grad_norms"]
+
+    # Compute rolling stats.
+    loss_mean, loss_std = _compute_rolling_stats(losses, count, interval)
+    grad_norm_mean, grad_norm_std = _compute_rolling_stats(grad_norms, count, interval)
+
+    # Check whether the current metrics are within the allowed thresholds.
+    is_loss_ok = (loss - loss_mean) <= scaling_factor * loss_std
+    if grad_norm is not None:
+      is_grad_norm_ok = (grad_norm - grad_norm_mean) <= scaling_factor * grad_norm_std
+      is_ok = jnp.logical_and(is_loss_ok, is_grad_norm_ok)
+    else:
+      is_ok = is_loss_ok
+
+    # Only enforce skipping once at least half the interval (minimum 2 elements) is filled.
+    min_history = max(2, interval // 2)
+    is_warmup = (count + 1) < min_history
+    is_ok = jnp.logical_or(is_warmup, is_ok)
+
+    # Conditionally execute the inner optimizer to prevent momentum poisoning.
+    def do_update():
+      return inner_opt.update(updates, state["inner_state"], params, **extra_args)
+
+    def skip_update():
+      # Use jax.debug.callback so the warning is emitted even under jax.jit / jax.lax.cond.
+      jax.debug.callback(lambda c: max_logging.warning(f"Step {c}: Optimizer step skipped due to spike."), count)
+      inner_updates = jax.tree_util.tree_map(jnp.zeros_like, updates)
+      return inner_updates, state["inner_state"]
+
+    inner_updates, new_inner_state = jax.lax.cond(is_ok, do_update, skip_update)
+
+    # Update rolling buffers (append even if skipped so spikes can become the new baseline).
+    idx = count % interval
+    new_losses = losses.at[idx].set(loss)
+
+    new_grad_norms = grad_norms
+    if grad_norm is not None:
+      new_grad_norms = grad_norms.at[idx].set(grad_norm)
+
+    new_state = {
+        "inner_state": new_inner_state,
+        "losses": new_losses,
+        "grad_norms": new_grad_norms,
+        "count": count + 1,
+        "is_skipped": jnp.logical_not(is_ok),
+    }
+    return inner_updates, new_state
+
+  return optax.GradientTransformationExtraArgs(init_fn, update_fn)
+
+
 def get_optimizer(config, learning_rate_schedule, model=None):
   """Create optimizer."""
   if config.opt_type == "adamw":
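
For orientation, here is a minimal sketch of driving the wrapper from a training step, assuming the train loop threads `loss` and `grad_norm` through as extra args. The `optax.adamw` inner optimizer, the `interval`/`scaling_factor` values, and the toy params/grads are illustrative assumptions, not part of this change; `skip_step_on_spikes` is the function added above.

import jax.numpy as jnp
import optax

opt = skip_step_on_spikes(optax.adamw(1e-3), interval=16, scaling_factor=6.0)

params = {"w": jnp.ones((4,))}
opt_state = opt.init(params)

grads = {"w": jnp.full((4,), 0.5)}
loss = jnp.float32(2.0)               # current step's scalar loss
grad_norm = optax.global_norm(grads)  # scalar L2 norm of the grad pytree

# `loss` and `grad_norm` ride along as extra args; update_fn pops them
# before they reach the inner optimizer.
updates, opt_state = opt.update(grads, opt_state, params, loss=loss, grad_norm=grad_norm)
params = optax.apply_updates(params, updates)

Note that the rolling statistics use the unbiased (N - 1) estimator, so the std is degenerate until at least two steps are recorded; the warmup guard in update_fn covers that window.
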
@@ -100,6 +219,13 @@ def get_optimizer(config, learning_rate_schedule, model=None):
   else:
     raise ValueError(f"{config.opt_type=} is not supported.")

+  if getattr(config, "skip_step_on_spikes", False):
+    base_opt = skip_step_on_spikes(
+        base_opt,
+        interval=config.skip_step_interval,
+        scaling_factor=config.skip_step_scaling_factor,
+    )
+
   # If a whitelist of trainable parameters is provided, freeze everything else.
   # When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
   trainable_patterns = getattr(config, "trainable_parameters_mask", None)