Remove undocumented fp8 NaN sanitization from train_step #3721
Merged
copybara-service[bot] merged 1 commit into main on Apr 24, 2026
Conversation
Codecov Report: ✅ All modified and coverable lines are covered by tests.
tdophung reviewed Apr 22, 2026
Drop both the OWG bookkeeping correction and the blanket `jnp.nan_to_num` over float gradients. Both were introduced as drive-by additions in an NNX migration commit with no fp8 test, repro, or justification in the commit message.

A/B on V6e-8, FSDP=8, 10 steps:

- gpt3-52k fp8: bit-identical with or without the blocks.
- llama2-7b fp8: NaN at step 2 either way; the blanket mask was previously hiding this as silent zeroed grads.

Real fp8 + FSDP convergence issues should be fixed upstream in AQT / the fp8 backward pass, not masked in the trainer. Surfacing the NaN lets us actually investigate it.
Force-pushed from cd5a4f8 to 3d21209
gobbleturk approved these changes Apr 23, 2026
bvandermoon approved these changes Apr 23, 2026
Description
The `train_step` in `src/maxtext/trainers/pre_train/train.py` contained an fp8-specific NaN handling block with two parts:

1. It overwrites the `OVERWRITE_WITH_GRADIENT` bucket (AQT's fp8 `amax_history`/`scale`) with the previous step's values before `apply_gradients`, then restores them after to "bypass optimizer corruption."
2. It applies a blanket `jnp.nan_to_num` over every float gradient in the tree, zeroing any NaN in any parameter gradient unconditionally.

This PR removes both.
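For reference, the removed blanket mask amounts to something like the following. This is a sketch with illustrative names (`sanitize_grads`, the example pytree), not the actual trainer code:

```python
import jax
import jax.numpy as jnp


def sanitize_grads(grads):
    """Blanket mask over a gradient pytree: replace NaN/Inf in every
    floating-point leaf with zeros, unconditionally. Non-float leaves
    (e.g. integer counters) pass through untouched."""
    def _mask(g):
        if jnp.issubdtype(g.dtype, jnp.floating):
            return jnp.nan_to_num(g, nan=0.0, posinf=0.0, neginf=0.0)
        return g
    return jax.tree_util.tree_map(_mask, grads)


# A NaN gradient is silently zeroed, so the optimizer step looks "stable":
grads = {"w": jnp.array([1.0, jnp.nan]), "step": jnp.array(3)}
masked = sanitize_grads(grads)
print(masked["w"])  # [1. 0.] -- the NaN is gone, and so is the signal
```

This is exactly why the mask hides convergence bugs: the update proceeds with a zeroed gradient instead of failing loudly.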
Why
The logic was introduced in commit
6a0f895c3("NNX migration: NNX utils") as a drive-by addition inside an unrelated migration commit. The commit message mentions sharding utilities, not fp8, not NaN, notamax_history. No accompanying test, no repro config, no reference to a failure-mode bug report. I tried to find the workload that justified it — searched theanfals/fp8_*branches and the full commit history foramax_history— and came up empty. If a repro exists, it was never checked in.Flagged the blanket
nan_to_numas a concern: it silently zeroes parameter-gradient NaN, which hides convergence bugs behind the appearance of stable training. I agreed and ran an A/B on V6e-8, 10 steps, FSDP=8, synthetic data:gpt3-52kgpt3-52kllama2-7bstable(was hidden by mask)The llama2-7b fp8 run exposes a pre-existing convergence issue that the blanket mask was silently hiding. The OWG block has no observable effect on any of these configs. On every workload we can actually test, the entire block is dead code.
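Once the NaN surfaces, the first question is which gradients contain it. A small diagnostic along these lines (illustrative, not part of this PR) localizes the offending leaves so the fp8 backward pass can be investigated:

```python
import jax
import jax.numpy as jnp


def count_nans(tree):
    """Count NaN entries per leaf of a gradient pytree. Useful for
    pinpointing which parameter's gradient produced the NaN instead
    of zeroing it away."""
    return jax.tree_util.tree_map(
        lambda g: int(jnp.isnan(g).sum())
        if jnp.issubdtype(g.dtype, jnp.floating) else 0,
        tree,
    )


grads = {"wq": jnp.array([0.1, jnp.nan, jnp.nan]), "step": jnp.array([2])}
print(count_nans(grads))  # {'wq': 2, 'step': 0}
```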
Behavior change
With this change, a NaN in the fp8 gradients is no longer silently zeroed: it propagates into the loss and trips `abort_on_nan_loss`. This is the intended behavior change; silent masking should not be the default. If a specific fp8 config genuinely relies on the OWG correction, that should be demonstrated with a repro and fixed properly (ideally upstream in AQT / the fp8 backward pass rather than in the trainer).
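The fail-fast behavior can be sketched as below. This is a minimal host-side check, assuming a config flag named like `abort_on_nan_loss`; the real trainer's check lives elsewhere and runs on host values outside jit, since you cannot branch on a traced value inside a jitted step:

```python
import jax.numpy as jnp


def check_loss(loss, step, abort_on_nan_loss=True):
    """Abort the run on NaN loss instead of training through it.
    Host-side check (outside jit); names are illustrative."""
    if abort_on_nan_loss and bool(jnp.isnan(loss)):
        raise FloatingPointError(f"NaN loss at step {step}; aborting run")
    return loss


check_loss(jnp.float32(14.317), step=0)  # passes through unchanged
```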
Tests
A/B on V6e-8, 10 steps, FSDP=8, synthetic data, `attention=dot_product`:

- `gpt3-52k` bf16 — loss 14.411 → 14.405, stable, no NaN.
- `gpt3-52k` fp8 (`quantization=fp8 sharding_tolerance=0.10`) — loss 14.317 → 14.320, stable, no NaN.
- `llama2-7b` fp8 — NaN at step 2, `abort_on_nan_loss` fires as expected.

Bit-identical results compared to runs with the blocks still present (for configs where the old code was dead); the only behavioral delta is fp8 NaN now surfacing instead of being zeroed.
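The bit-identical comparison above can be checked mechanically. A sketch of such a pytree comparison (an assumed helper for illustration, not code from this repo; NaNs in matching positions are counted as equal so a masked-vs-unmasked A/B still lines up):

```python
import jax
import jax.numpy as jnp


def trees_identical(a, b):
    """True iff two pytrees share a structure and every pair of leaves
    matches exactly in dtype and value (NaN positions included)."""
    if jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b):
        return False
    leaves_a = jax.tree_util.tree_leaves(a)
    leaves_b = jax.tree_util.tree_leaves(b)
    return all(
        x.dtype == y.dtype and bool(jnp.array_equal(x, y, equal_nan=True))
        for x, y in zip(leaves_a, leaves_b)
    )
```

In practice one would run this over the final checkpoint parameters of the with-blocks and without-blocks runs.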
Repros:
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [ ] `gemini-review` label.