@@ -387,31 +387,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
387387 else :
388388 grads = raw_grads
389389
390- # fp8 fix: sanitize NaN OWG (overwrite-with-gradient) stats before apply_gradients.
391- # Under FSDP, the fp8 output gradient amax can be NaN at step 0, which propagates into
392- # amax_history and corrupts future steps. Replace NaN OWG entries with the current state
393- # values (skip the amax update for that step) instead of letting NaN flow through.
394- # Also restore OWG values after apply_gradients to bypass optimizer corruption
395- # (Adam should not update fp8 scale/amax_history).
396- fp8_stats = dict (grads ).get (maxtext_utils .OVERWRITE_WITH_GRADIENT , None )
397- if fp8_stats is not None :
398- if maxtext_utils .OVERWRITE_WITH_GRADIENT in state .params :
399- current_fp8 = state .params [maxtext_utils .OVERWRITE_WITH_GRADIENT ]
400- fp8_stats = jax .tree_util .tree_map (
401- lambda new , cur : jnp .where (jnp .isnan (new ), cur , new ),
402- fp8_stats ,
403- current_fp8 ,
404- )
405- else :
406- fp8_stats = jax .tree_util .tree_map (lambda x : jnp .nan_to_num (x , nan = 0.0 ), fp8_stats )
407- grads = dict (grads )
408- grads [maxtext_utils .OVERWRITE_WITH_GRADIENT ] = fp8_stats
409- # Zero out any remaining NaN in float gradients to prevent param corruption
410- grads = jax .tree_util .tree_map (
411- lambda x : jnp .nan_to_num (x , nan = 0.0 ) if jnp .issubdtype (x .dtype , jnp .floating ) else x ,
412- grads ,
413- )
414-
415390 if config .optimizer_memory_host_offload :
416391 state = state .replace (
417392 opt_state = jax .device_put (
@@ -462,25 +437,7 @@ def move(path, value):
462437 )
463438 else :
464439 new_state = state .apply_gradients (grads = full_grads )
465- # fp8 fix: restore sanitized OWG values, bypassing any optimizer update to fp8 stats.
466- if fp8_stats is not None :
467- new_params = dict (new_state .params )
468- new_params [maxtext_utils .OVERWRITE_WITH_GRADIENT ] = fp8_stats
469- new_state = new_state .replace (params = new_params )
470- has_batch_stats = (
471- config .weight_sparsity_n
472- and config .weight_sparsity_m
473- and bool (aux .get ("batch_stats" ))
474- and isinstance (state .params , dict )
475- and "batch_stats" in state .params
476- )
477440
478- if has_batch_stats :
479- new_params = dict (new_state .params )
480- new_params ["batch_stats" ] = max_utils .unbox_logicallypartioned (
481- aux ["batch_stats" ]
482- )
483- new_state = new_state .replace (params = new_params )
484441 # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
485442 if config .routed_bias and config .routed_bias_update_rate > 0.0 and moe_bias_updates is not None :
486443 target_path = ("params" , "decoder" , "moe_layers" , "DeepSeekMoeBlock_0" , "MoeBlock_0" , "gate" , "bias" )
0 commit comments