# NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500

ecnal-cienet wants to merge 1 commit into `main`
Conversation
**bvandermoon** left a comment:

> Generally looking good to me. Thank you @ecnal-cienet! Is it possible to increase the code coverage here?
Thanks — two things going on here, and both have been addressed:

1. **Config bug suppressing patch coverage.** Most of the "uncovered" lines in this PR aren't actually uncovered; coverage.py never traced them.
2. **Targeted unit tests for new NNX lines.** Added 34 tests across 4 files, specifically against the diff of this PR.
- Add `TrainStateNNX` (`layers/train_state_nnx.py`) with checkpoint and unit tests
- Refactor `model_creation_utils` with `create_nnx_abstract_model()`; add NNX support to `muon_utils`
- Add `get_abstract_state_nnx()` and `get_nnx_named_sharding_with_scan_axis()` to `maxtext_utils.py`
- Wire NNX train state into `train.py` and `train_utils.py` with `pure_nnx` dispatch
## NNX Migration Route Map

- `pure_nnx` flag, `init_state_fn`, `TrainStateNNX`, NNX utils. Linen workflow unchanged. (PR #3427)
- `get_abstract_state_nnx`, `get_named_sharding_nnx`, `set_named_sharding_nnx`, `get_partition_spec_nnx`, `get_mesh_from_config`. (PR #3470)
- `TrainStateNNX`, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)

## Description
### `TrainStateNNX` and unit tests

`src/maxtext/layers/train_state_nnx.py` implements the `TrainStateNNX` container, which holds an NNX model and its Optax optimizer as a single composable unit. Unit tests cover state creation, optimizer step, and Orbax checkpoint round-trip:

- `tests/unit/train_state_nnx_test.py`
- `tests/unit/train_state_nnx_checkpoint_test.py`
### Muon optimizer and model creation utilities

- `muon_utils.py` — updated to support NNX models alongside Linen.
- `model_creation_utils.py` — refactored to expose `create_nnx_abstract_model` and `from_config`, which create and initialize an NNX model from a config without running a full forward pass.

### End-to-end training loop (`train.py` + supporting modules)

The core training loop in `train.py` now dispatches on `pure_nnx` at every major decision point:

- Sharding (`sharding.py`) — `maybe_update_params_sharding_with_opt` dispatches to a new `maybe_update_params_sharding_with_opt_nnx`, which extracts `nnx.Param`-only shardings from the flat `nnx.State` without accessing `.params`.
- Gradient accumulation (`gradient_accumulation.py`) — the NNX path uses `nnx.value_and_grad` with `nnx.split`/`nnx.merge` per microbatch inside `jax.lax.scan`, carrying non-Param `rest` state (RNGs) through the loop.
- Train/eval signatures (`maxtext_utils.py`) — `get_functional_train_with_signature` and `get_functional_eval_with_signature` use a 2-element `in_shardings` tuple `(state, batch)` for NNX (no rng argument), vs. a 3-element tuple for Linen.
- Checkpointing (`checkpointing.py`) — `maybe_save_checkpoint` converts `nnx.State` to a plain dict via `state.to_pure_dict()` before the Orbax save; `load_state_if_possible` restores via `nnx.replace_by_pure_dict(abstract_state, dict)`.

## Tests
### Unit tests

### Integration tests

### Coverage of new tests added in this PR
These tests specifically target the NNX patch lines flagged as uncovered in the Codecov report.
- `tests/unit/sharding_nnx_test.py` — `sharding.maybe_update_params_sharding_with_opt_nnx`: Linen↔NNX dispatch, `nnx.Param`-only filter, no-op when `shard_optimizer_over_data=False`, Zero-1 mu propagation, SGD raises `NotImplementedError`, chained-optimizer recursion
- `tests/unit/gradient_accumulation_nnx_test.py` — `gradient_accumulation.gradient_accumulation_loss_and_grad` NNX branch: `nnx.split`/`nnx.value_and_grad`/`nnx.update` paths and bf16 cast under Zero-1
- `tests/unit/checkpointing_nnx_load_test.py` — `checkpointing.load_state_if_possible` NNX branch: `nnx.split(model, nnx.Param, ...)` for param-only restore + Linen passthrough sanity
- `tests/unit/train_nnx_test.py` — `pre_train.train.loss_fn`/`train_step`/`eval_step` NNX paths: full aux dict, intermediates capture, indexer dense warm-up, vocab-tiling raises, optimizer step increment, gradient clipping, DPO raises
- `tests/unit/train_utils_nnx_test.py` — `train_utils.setup_train_loop` NNX patterns: `create_train_state_fn` closure, `nnx.split`/`nnx.merge` round-trip
- `tests/integration/setup_train_loop_nnx_test.py` — `setup_train_loop` with `pure_nnx=True`: verifies the returned `TrainStateNNX`, Param-only split structure matches `state_mesh_shardings`, DPO `NotImplementedError` guard

### Local Test Result
Report
## Checklist

Before submitting this PR, please make sure (put an X in the square brackets):
`gemini-review` label.