Skip to content

Commit cd5a4f8

Browse files
committed
Gate fp8 NaN gradient sanitization on quantization config
The NaN sanitization introduced for fp8 delayed-scaling FSDP ran on every step regardless of quantization, adding a ~2-3% step-time regression on non-fp8 workloads (per-float-grad jnp.nan_to_num tree_map). The failure mode only occurs under fp8, so gate the block on config.quantization in {"fp8", "fp8_full", "nanoo_fp8"}. Non-fp8 workloads skip the tree_map entirely; fp8 behavior is unchanged (verified: step 1 loss still finite under gpt3-52k + FSDP=8).
1 parent 172d0f1 commit cd5a4f8

1 file changed

Lines changed: 20 additions & 18 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -362,24 +362,26 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
362362
# values (skip the amax update for that step) instead of letting NaN flow through.
363363
# Also restore OWG values after apply_gradients to bypass optimizer corruption
364364
# (Adam should not update fp8 scale/amax_history).
365-
fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None)
366-
if fp8_stats is not None:
367-
if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params:
368-
current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT]
369-
fp8_stats = jax.tree_util.tree_map(
370-
lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
371-
fp8_stats,
372-
current_fp8,
373-
)
374-
else:
375-
fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats)
376-
grads = dict(grads)
377-
grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
378-
# Zero out any remaining NaN in float gradients to prevent param corruption
379-
grads = jax.tree_util.tree_map(
380-
lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
381-
grads,
382-
)
365+
fp8_stats = None
366+
if config.quantization in ("fp8", "fp8_full", "nanoo_fp8"):
367+
fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None)
368+
if fp8_stats is not None:
369+
if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params:
370+
current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT]
371+
fp8_stats = jax.tree_util.tree_map(
372+
lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
373+
fp8_stats,
374+
current_fp8,
375+
)
376+
else:
377+
fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats)
378+
grads = dict(grads)
379+
grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
380+
# Zero out any remaining NaN in float gradients to prevent param corruption
381+
grads = jax.tree_util.tree_map(
382+
lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
383+
grads,
384+
)
383385

384386
if config.optimizer_memory_host_offload:
385387
state = state.replace(

0 commit comments

Comments (0)