NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop#3500

Open

ecnal-cienet wants to merge 1 commit into main from feat/nnx-trainstate-and-training-loop

Conversation

Collaborator

@ecnal-cienet ecnal-cienet commented Mar 25, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. 🔄 [This PR] NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. ✅ NNX sharding diagnostics, bidirectional Linen↔NNX checkpoint conversion utilities, and post-training fixes. (PR #3652)
  5. ❌ Enable NNX by default; fix unit and integration test failures.
  6. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

Note: This is the third in a series of NNX migration PRs. With this PR, pure_nnx=True runs a complete NNX training loop — initialization, sharding, gradient accumulation, eval, and checkpointing — without hitting any NotImplementedError. The pure_nnx flag still defaults to False, preserving the existing Linen workflow unchanged.

TrainStateNNX and unit tests

src/maxtext/layers/train_state_nnx.py implements the TrainStateNNX container, which holds an NNX model and its Optax optimizer as a single composable unit. Unit tests cover state creation, optimizer step, and Orbax checkpoint round-trip:

  • tests/unit/train_state_nnx_test.py
  • tests/unit/train_state_nnx_checkpoint_test.py

Muon optimizer and model creation utilities

  • muon_utils.py — updated to support NNX models alongside Linen.
  • model_creation_utils.py — refactored to expose create_nnx_abstract_model and from_config, which create and initialize an NNX model from a config without running a full forward pass.

End-to-end training loop (train.py + supporting modules)

The core training loop in train.py now dispatches on pure_nnx at every major decision point:

  • Sharding (sharding.py) — maybe_update_params_sharding_with_opt dispatches to a new maybe_update_params_sharding_with_opt_nnx, which extracts nnx.Param-only shardings from the flat nnx.State without accessing .params.
  • Gradient accumulation (gradient_accumulation.py) — NNX path uses nnx.value_and_grad with nnx.split / nnx.merge per microbatch inside jax.lax.scan, carrying non-Param rest state (RNGs) through the loop.
  • Train/eval step JIT (maxtext_utils.py) — get_functional_train_with_signature and get_functional_eval_with_signature use a 2-element in_shardings tuple (state, batch) for NNX (no rng argument), vs. 3-element for Linen.
  • Checkpointing (checkpointing.py) — maybe_save_checkpoint converts nnx.State to a plain dict via state.to_pure_dict() before Orbax save; load_state_if_possible restores via nnx.replace_by_pure_dict(abstract_state, dict).

Tests

Unit tests

python3 -m pytest tests/unit/train_state_nnx_test.py -v
python3 -m pytest tests/unit/train_state_nnx_checkpoint_test.py -v
python3 -m pytest tests/unit/maxtext_utils_test.py -v
python3 -m pytest tests/unit/maxtext_utils_nnx_test.py -v
python3 -m pytest tests/unit/muon_utils_test.py -v
python3 -m pytest tests/unit/nnx_decoders_test.py -v
python3 -m pytest tests/unit/train_compile_test.py -v
python3 -m pytest tests/unit/sharding_nnx_test.py -v
python3 -m pytest tests/unit/gradient_accumulation_nnx_test.py -v
python3 -m pytest tests/unit/checkpointing_nnx_load_test.py -v
python3 -m pytest tests/unit/train_nnx_test.py -v
python3 -m pytest tests/unit/train_utils_nnx_test.py -v

Integration tests

python3 -m pytest tests/integration/setup_train_loop_nnx_test.py -v

Coverage of new tests added in this PR

These tests specifically target the NNX patch lines flagged as uncovered in the Codecov report.

  • tests/unit/sharding_nnx_test.py (6 tests) — targets sharding.maybe_update_params_sharding_with_opt_nnx: Linen↔NNX dispatch, nnx.Param-only filter, no-op when shard_optimizer_over_data=False, Zero-1 mu propagation, SGD raises NotImplementedError, chained-optimizer recursion.
  • tests/unit/gradient_accumulation_nnx_test.py (3 tests) — targets the gradient_accumulation.gradient_accumulation_loss_and_grad NNX branch: nnx.split / nnx.value_and_grad / nnx.update paths and bf16 cast under Zero-1.
  • tests/unit/checkpointing_nnx_load_test.py (3 tests) — targets the checkpointing.load_state_if_possible NNX branch: nnx.split(model, nnx.Param, ...) for param-only restore, plus a Linen passthrough sanity check.
  • tests/unit/train_nnx_test.py (9 tests) — targets the pre_train.train.loss_fn / train_step / eval_step NNX paths: full aux dict, intermediates capture, indexer dense warm-up, vocab-tiling raises, optimizer step increment, gradient clipping, DPO raises.
  • tests/unit/train_utils_nnx_test.py (6 tests) — targets train_utils.setup_train_loop NNX patterns: create_train_state_fn closure, nnx.split / nnx.merge round-trip.
  • tests/integration/setup_train_loop_nnx_test.py (3 tests) — end-to-end setup_train_loop with pure_nnx=True: verifies the returned TrainStateNNX, that the Param-only split structure matches state_mesh_shardings, and the DPO NotImplementedError guard.

Local Test Result

Report

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 3 times, most recently from 4bae533 to e6baabd on March 25, 2026 21:48

codecov Bot commented Mar 25, 2026

@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 2 times, most recently from 754df44 to 8055cc8 on March 26, 2026 17:10
@ecnal-cienet changed the title from "Feat/nnx trainstate and training loop" to "NNX migration prep (3/N): Feat/nnx trainstate and training loop" on Mar 26, 2026
@ecnal-cienet changed the title from "NNX migration prep (3/N): Feat/nnx trainstate and training loop" to "NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop" on Mar 27, 2026
@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 2 times, most recently from a906f15 to a726aac on March 31, 2026 14:00
@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 5 times, most recently from 6851939 to cbce870 on April 6, 2026 18:33
@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 6 times, most recently from 6b778c3 to 2fdce03 on April 13, 2026 14:58
@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 4 times, most recently from f4f0d70 to 39d3587 on April 16, 2026 17:45
@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch from 39d3587 to 69f71ed on April 16, 2026 22:20
Collaborator

@bvandermoon bvandermoon left a comment


Generally looking good to me. Thank you @ecnal-cienet!

Comment thread on src/maxtext/common/checkpointing.py
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch from 29bbe47 to 2c33dc7 Compare April 22, 2026 20:22
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 3 times, most recently from 7b2bd5b to 1abe206 Compare April 22, 2026 21:36
@ecnal-cienet
Collaborator Author

Codecov Report

❌ Patch coverage is 42.95612% with 247 lines in your changes missing coverage. Please review.
Files with missing lines (patch % / missing lines):
  • src/maxtext/trainers/pre_train/train.py — 32.35% (94 missing, 21 partials) ⚠️
  • src/maxtext/utils/sharding.py — 5.55% (50 missing, 1 partial) ⚠️
  • src/maxtext/utils/maxtext_utils.py — 70.66% (15 missing, 7 partials) ⚠️
  • src/maxtext/utils/train_utils.py — 27.58% (19 missing, 2 partials) ⚠️
  • src/maxtext/utils/gradient_accumulation.py — 27.77% (8 missing, 5 partials) ⚠️
  • src/maxtext/common/checkpointing.py — 43.75% (6 missing, 3 partials) ⚠️
  • src/maxtext/utils/muon_utils.py — 66.66% (4 missing, 3 partials) ⚠️
  • src/maxtext/layers/nnx_decoders.py — 0.00% (5 missing, 1 partial) ⚠️
  • src/maxtext/utils/model_creation_utils.py — 93.18% (2 missing, 1 partial) ⚠️

Is it possible to increase the code coverage here?

Thanks — two things going on here, and both have been addressed:

1. Config bug suppressing patch coverage. Most of the "uncovered" lines in this PR aren't actually uncovered; coverage.py never traced them. .coveragerc had [run] source = MaxText but the new code lives in the lower-case maxtext package (the NNX-migrated path). Lines imported via maxtext.* were reported to Codecov with coverage=null — including plain import statements like from flax import nnx in src/maxtext/common/checkpointing.py:23. Fixed in a separate PR that adds maxtext to [run] source: #<PR_NUM_FOR_COVERAGERC>. Once that lands, rebasing this PR should flip ~444 phantom-uncovered lines into real hit/miss status.
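Assuming a standard .coveragerc layout (the exact file contents are not shown in this thread), the fix described above would look roughly like:

```ini
# .coveragerc — sketch; adds the lower-case package alongside the legacy one
[run]
source =
    MaxText
    maxtext
```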

2. Targeted unit tests for new NNX lines. Added 34 tests across 4 files, specifically against the diff of this PR:

  • tests/unit/muon_utils_test.py — new file, 26 tests. Covers every branch of transform_logic, both NNX and Linen branches of get_muon_weight_dimension_numbers, and both leaf types in _print_structure_debug (nn.LogicallyPartitioned and jax.ShapeDtypeStruct). The file goes from 0 tests to 85.7% local line coverage.
  • tests/unit/maxtext_utils_test.py — +5 tests on the new get_nnx_named_sharding_with_scan_axis helper (scan-axis insertion, deduplication guard, MaskedNode passthrough, empty-pspec, string→tuple wrapping).
  • tests/unit/nnx_decoders_test.py — +1 test (test_multimodal_input_unpacks_into_individual_fields) for the MultimodalInput unpacking block in NNXDecoder.__call__.
  • tests/unit/train_state_nnx_checkpoint_test.py — +5 tests in TestMaybeSaveCheckpointStepAlignment (answering your previous comment).

@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 7 times, most recently from 090b545 to ed8c599 on April 24, 2026 22:06
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
@ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch from ed8c599 to 77202a8 on April 24, 2026 23:38