Remove undocumented fp8 NaN sanitization from train_step #3721

Merged
copybara-service[bot] merged 1 commit into main from fix/gate-fp8-nan-sanitization
Apr 24, 2026
Conversation

@ecnal-cienet ecnal-cienet commented Apr 22, 2026

Description

The train_step in src/maxtext/trainers/pre_train/train.py contained an fp8-specific NaN handling block with two parts:

  1. An OWG-specific correction — replaces NaN entries in the OVERWRITE_WITH_GRADIENT bucket (AQT's fp8 amax_history / scale) with the previous step's values before apply_gradients, then restores them after to "bypass optimizer corruption."
  2. A blanket jnp.nan_to_num over every float gradient in the tree — zeroes any NaN in any parameter gradient unconditionally.

This PR removes both.
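For context, here is a minimal sketch of what the blanket mask amounted to — a paraphrase for illustration, not the exact MaxText code, and `sanitize_grads` is a hypothetical name:

```python
import jax
import jax.numpy as jnp


def sanitize_grads(grads):
  """Hypothetical paraphrase of the removed blanket mask: zero every NaN
  in every floating-point leaf of the gradient pytree, unconditionally."""
  return jax.tree_util.tree_map(
      lambda g: jnp.nan_to_num(g, nan=0.0)
      if jnp.issubdtype(g.dtype, jnp.inexact) else g,
      grads,
  )
```

Because this runs over every float leaf, a genuinely divergent gradient is replaced with zeros and the optimizer step proceeds as if nothing happened — which is exactly the failure mode that kept the llama2-7b fp8 NaN hidden.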

Why

The logic was introduced in commit 6a0f895c3 ("NNX migration: NNX utils") as a drive-by addition inside an unrelated migration commit. The commit message mentions sharding utilities, not fp8, not NaN, not amax_history. No accompanying test, no repro config, no reference to a failure-mode bug report. I tried to find the workload that justified it — searched the anfals/fp8_* branches and the full commit history for amax_history — and came up empty. If a repro exists, it was never checked in.

A reviewer flagged the blanket nan_to_num as a concern: it silently zeroes NaN in parameter gradients, which hides convergence bugs behind the appearance of stable training. I agreed and ran an A/B on V6e-8, 10 steps, FSDP=8, synthetic data:

| Config | Original | OWG removed, mask kept | Both removed |
| --- | --- | --- | --- |
| bf16 gpt3-52k | stable | stable | stable (bit-identical) |
| fp8 gpt3-52k | stable | stable (bit-identical) | stable (bit-identical) |
| fp8 llama2-7b | stable (was hidden by mask) | NaN at step 2 | NaN at step 2 |

The llama2-7b fp8 run exposes a pre-existing convergence issue that the blanket mask was silently hiding. The OWG block has no observable effect on any of these configs; on every workload we can actually test, it is dead code.

Behavior change

  • Non-fp8 recipes: no change in behavior. Code path is simply removed.
  • fp8 recipes that were previously stable: no change — the OWG correction never fired on those anyway.
  • fp8 recipes that appeared stable but had real gradient NaN being zeroed by the blanket mask: will now surface NaN loss and trigger abort_on_nan_loss. This is the intended behavior change — silent masking should not be the default.
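The intended surfacing behavior can be illustrated with a small sketch — `check_loss` is a hypothetical helper for illustration, not the actual MaxText abort_on_nan_loss implementation:

```python
import jax.numpy as jnp


def check_loss(loss, abort_on_nan_loss=True):
  # Hypothetical helper: with the blanket mask gone, a NaN loss should
  # surface and abort the run instead of training on silently.
  if abort_on_nan_loss and bool(jnp.isnan(loss)):
    raise FloatingPointError("NaN loss encountered; aborting run.")
  return loss
```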

If a specific fp8 config genuinely relies on the OWG correction, that should be demonstrated with a repro and fixed properly (ideally upstream in AQT / the fp8 backward pass rather than in the trainer).

Tests

A/B on V6e-8, 10 steps, FSDP=8, synthetic data, attention=dot_product.

  • gpt3-52k bf16 — loss 14.411 → 14.405, stable, no NaN.
  • gpt3-52k fp8 (quantization=fp8 sharding_tolerance=0.10) — loss 14.317 → 14.320, stable, no NaN.
  • llama2-7b fp8 — NaN at step 2, abort_on_nan_loss fires as expected.

Bit-identical results compared to runs with the blocks still present (for configs where the old code was dead); the only behavioral delta is fp8 NaN now surfacing instead of being zeroed.
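When the surfaced llama2-7b NaN gets investigated, something like the following can pinpoint which gradient leaves diverge first — a debugging sketch, not part of this PR, and `nan_leaves` is a made-up helper:

```python
import jax
import jax.numpy as jnp


def nan_leaves(grads):
  # Debugging sketch: return the pytree paths of floating-point leaves
  # that contain at least one NaN.
  flat = jax.tree_util.tree_flatten_with_path(grads)[0]
  return [
      jax.tree_util.keystr(path)
      for path, leaf in flat
      if jnp.issubdtype(leaf.dtype, jnp.inexact)
      and bool(jnp.any(jnp.isnan(leaf)))
  ]
```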

Repros:

```shell
# Non-fp8
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  model_name=gpt3-52k per_device_batch_size=1 ici_fsdp_parallelism=8 \
  steps=10 dataset_type=synthetic override_model_config=True attention=dot_product

# fp8 (gpt3-52k — stable regardless)
# ...same, plus quantization=fp8 sharding_tolerance=0.10

# fp8 (llama2-7b — surfaces pre-existing NaN)
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  model_name=llama2-7b per_device_batch_size=1 ici_fsdp_parallelism=8 \
  steps=10 dataset_type=synthetic attention=dot_product \
  quantization=fp8 sharding_tolerance=0.10 max_target_length=1024
```

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet changed the title Gate fp8 NaN gradient sanitization on quantization config Fix: Skip fp8 NaN gradient sanitization on non-fp8 workloads Apr 22, 2026
codecov Bot commented Apr 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


Comment thread on src/maxtext/trainers/pre_train/train.py (outdated):
Drop both the OWG bookkeeping correction and the blanket
jnp.nan_to_num over float gradients. Both were introduced as
drive-by additions in an NNX migration commit with no fp8 test,
repro, or justification in the commit message.

A/B on V6e-8, FSDP=8, 10 steps:
- gpt3-52k fp8: bit-identical with or without the blocks.
- llama2-7b fp8: NaN at step 2 either way; the blanket mask was
  previously hiding this as silent zeroed grads.

Real fp8 + FSDP convergence issues should be fixed upstream in
AQT / the fp8 backward pass, not masked in the trainer. Surfacing
the NaN lets us actually investigate it.
@ecnal-cienet ecnal-cienet force-pushed the fix/gate-fp8-nan-sanitization branch from cd5a4f8 to 3d21209 Compare April 22, 2026 23:28
@ecnal-cienet ecnal-cienet changed the title Fix: Skip fp8 NaN gradient sanitization on non-fp8 workloads Remove undocumented fp8 NaN sanitization from train_step Apr 22, 2026
@copybara-service copybara-service Bot merged commit 412902a into main Apr 24, 2026
59 checks passed
@copybara-service copybara-service Bot deleted the fix/gate-fp8-nan-sanitization branch April 24, 2026 21:05