Skip to content

Commit 412902a

Browse files
Merge pull request #3721 from AI-Hypercomputer:fix/gate-fp8-nan-sanitization
PiperOrigin-RevId: 905211470
2 parents 49fd452 + 3d21209 commit 412902a

1 file changed

Lines changed: 0 additions & 43 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -387,31 +387,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
387387
else:
388388
grads = raw_grads
389389

390-
# fp8 fix: sanitize NaN OWG (overwrite-with-gradient) stats before apply_gradients.
391-
# Under FSDP, the fp8 output gradient amax can be NaN at step 0, which propagates into
392-
# amax_history and corrupts future steps. Replace NaN OWG entries with the current state
393-
# values (skip the amax update for that step) instead of letting NaN flow through.
394-
# Also restore OWG values after apply_gradients to bypass optimizer corruption
395-
# (Adam should not update fp8 scale/amax_history).
396-
fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None)
397-
if fp8_stats is not None:
398-
if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params:
399-
current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT]
400-
fp8_stats = jax.tree_util.tree_map(
401-
lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
402-
fp8_stats,
403-
current_fp8,
404-
)
405-
else:
406-
fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats)
407-
grads = dict(grads)
408-
grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
409-
# Zero out any remaining NaN in float gradients to prevent param corruption
410-
grads = jax.tree_util.tree_map(
411-
lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
412-
grads,
413-
)
414-
415390
if config.optimizer_memory_host_offload:
416391
state = state.replace(
417392
opt_state=jax.device_put(
@@ -462,25 +437,7 @@ def move(path, value):
462437
)
463438
else:
464439
new_state = state.apply_gradients(grads=full_grads)
465-
# fp8 fix: restore sanitized OWG values, bypassing any optimizer update to fp8 stats.
466-
if fp8_stats is not None:
467-
new_params = dict(new_state.params)
468-
new_params[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
469-
new_state = new_state.replace(params=new_params)
470-
has_batch_stats = (
471-
config.weight_sparsity_n
472-
and config.weight_sparsity_m
473-
and bool(aux.get("batch_stats"))
474-
and isinstance(state.params, dict)
475-
and "batch_stats" in state.params
476-
)
477440

478-
if has_batch_stats:
479-
new_params = dict(new_state.params)
480-
new_params["batch_stats"] = max_utils.unbox_logicallypartioned(
481-
aux["batch_stats"]
482-
)
483-
new_state = new_state.replace(params=new_params)
484441
# Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
485442
if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
486443
target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")

0 commit comments

Comments (0)