[NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True#3526
[NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True#3526ecnal-cienet wants to merge 1 commit into
Conversation
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
bac289f to
db75887
Compare
5a7f63b to
73213e0
Compare
7e33a09 to
2f34cfb
Compare
f4674bb to
b7d1f6d
Compare
9d05b96 to
450ef8d
Compare
e420909 to
8a27207
Compare
2d3b8a6 to
52219b5
Compare
|
🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
1 similar comment
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
There was a problem hiding this comment.
📋 Review Summary
This Pull Request successfully transitions MaxText to use JAX's NNX API by default by flipping enable_nnx, pure_nnx, and pure_nnx_decoder to True. The changes are comprehensive and robust, addressing flat NNX state checkpointing, DiLoCo training, parameter-only restoration, and pinning Linen-coupled integration/unit tests.
🔍 General Feedback
- High-Quality Code Migration: The migration of core features (including DiLoCo, param-only generation, and maxengine) to support NNX-native behavior is very well-structured and thoroughly covered by updated tests.
- Robust Sharding Alignments: Adopting
build_zero1_input_state_mesh_shardingsto overlay Param-leaf shardings on the flatnnx.Stateensures seamless compatibility with ZeRO-1 optimizers under NNX. - Preserved Parity: Pinning complex pipeline parallelism and fp8/sparsity tests to the Linen path is a pragmatic decision that avoids regressions while these features are migrated in subsequent iterations.
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
1 similar comment
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
PR6-PR10 promoted every routed-to-Linen feature to NNX-native; PR#2885 lands NNX-native pipeline parallelism. This PR flips the three defaults in base.yml so NNX is the production path, and bundles the NNX-only fixes that surface once pure_nnx=True (DiLoCo merge/checkpoint, Zero-1 input shardings on flat nnx.State, MTP sown-Variable handling, generate_param_only_checkpoint NNX flow, maxengine Linen-parity removal).
NNX Migration Route Map
pure_nnxflag,init_state_fn,TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)get_abstract_state_nnx,get_named_sharding_nnx,set_named_sharding_nnx,get_partition_spec_nnx,get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
custom_vjpfor NNX. (PR [NNX] NNX migration prep (10/N): vocab tiling custom_vjp with output-head carve-out #3849)enable_nnx,pure_nnx,pure_nnx_decoderfromFalse→Trueinbase.yml. Bundle the NNX-only fixes that surface oncepure_nnx=True.Description
PR6–PR10 promoted every routed-to-Linen feature to NNX-native; #2885 added NNX-native pipeline parallelism; #4040 added Qwix on NNX. This PR flips the three defaults in
base.ymland bundles the NNX-only fixes that surface oncepure_nnx=True.Changes
src/maxtext/configs/base.yml— flip defaultsenable_nnx: False → True,pure_nnx: False → True,pure_nnx_decoder: False → True.src/maxtext/utils/sharding.py— Zero-1 on flatnnx.Statebuild_zero1_input_state_mesh_shardingsoverlaysParam-leaf shardings on the flatnnx.State. The Linen callstate_mesh_shardings.replace(params=...)only works onTrainState.src/maxtext/trainers/pre_train/train.py,train_compile.py— NNX dispatchpure_nnx=True.Intermediatesown variables before grad so MTP auxiliary losses aren't differentiated as part of the main loss.src/maxtext/trainers/diloco/diloco.py,src/maxtext/common/checkpointing.py— DiLoCo under NNXDiLoCoTrainState.merge/.splitusennx.split, guarded against double-merging.maybe_save_checkpointreadsstate.stepunderenable_diloco, otherwisestate.optimizer.step.replace_nnx_model_paramsidentifies "model" leaves by path viatree_flatten_with_path. Preserves the originaltreedefsolax.condbranches still match, and is robust to future key additions toinner_state(addresses #3526 review comment).src/maxtext/utils/generate_param_only_checkpoint.py— NNX param-only restore{"value": ...}wrapping), opt_state path skipping, bf16 cast skipping rng leaves.src/maxtext/inference/maxengine/maxengine.py— drop Linen-vs-NNX parity assertssrc/maxtext/layers/nnx_wrappers.py— modernizeToLinen.__call___refresh_variable_trace_stateprivate-state workaround; use idiomaticnnx.split/nnx.update/nnx.mergeand filter unknown paths before assignment (addresses #3526 review comment).src/maxtext/layers/quantizations.py— OSS qwix import fix from #4040from qwix._src.utils import flax_util→from qwix._src import flax_util. PR#4040 (Copybara) referenced the Google-internalqwix._src.utilspath; OSS qwix hasflax_utildirectly under_src/. Unblocks OSS CI.src/maxtext/utils/{muon_utils,qk_clip_utils,train_utils}.py— NNX-shape adjustmentsmuon_utils.get_muon_weight_dimension_numbersdispatches by NNX-vs-Linen state shape.qk_clip_utilsbroadcasts over the correct axis under NNX.train_utils.jit_train_stepthreadsdropout_rng=Noneon the NNX path.src/maxtext/trainers/post_train/sft/train_sft_native.py— SFT NNX pathnnx.split(state)beforejit, dropout-rng threaded conditionally (mirrors the pre-train path).src/maxtext/checkpoint_conversion/{to_maxtext.py, utils/utils.py}— shared helperparam_key_parts_from_pathhelper;to_maxtext.pyalso handlesnn.LogicallyPartitionedshape access correctly.Tests
tests/unit/tiling_test.py::LossAndGradientCorrectnessTest— pin to Linen insetUp(builds viatransformer_as_linen); drop 6 stalepytest.skip("vocab tiling on NNX")guards (now NNX-native via PR10).tests/integration/maxengine_test.py— drop Linen-vs-NNX prefill/decode parity tests; NNX-only assertions kept.tests/unit/max_utils_test.py— pinUnscanTestto Linen viainit_pyconfig; drop the threehasattr(state, "model")branches (addresses #3526 review comment).tests/integration/diloco_test.py— NNX training-loop simulation + checkpoint coverage.tests/integration/generate_param_only_checkpoint_test.py— NNX param-only restore coverage.tests/unit/{muon_utils,maxtext_utils,optimizers,state_dtypes,train_state_nnx_checkpoint}_test.py— adjusted forTrainStateNNX/ flatnnx.Stateshapes.tests/unit/qwen3_next_vs_reference_test.py—eps→epsilonrename inQwen3NextRMSNorm_PT+ related cleanup.tests/integration/tokamax_test.py— split parameterized test intogmm_bf16andgmm_fp8cases.Stats
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.