
NNX migration prep (5/N): enable NNX by default #3526

Draft
ecnal-cienet wants to merge 5 commits into main from feat/nnx-set-defaults-true

Conversation


@ecnal-cienet ecnal-cienet commented Mar 31, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. ✅ NNX sharding diagnostics, bidirectional Linen↔NNX checkpoint conversion utilities, and post-training fixes. (PR #3652)
  5. 🔄 [This PR] Enable NNX by default; fix unit and integration test failures.
  6. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

Note: This is the fifth in a series of NNX migration PRs. This PR flips all three NNX flags to True in base.yml, making NNX the default training path, and fixes the test failures that surface as a result.

Config change

src/MaxText/configs/base.yml — three flags flipped to True:

enable_nnx: True
pure_nnx_decoder: True
pure_nnx: True

Test fixes

| File | Fix |
| --- | --- |
| src/maxtext/layers/nnx_decoders.py | Add multimodal_input=None to NNXDecoder.__call__ and unpack it into the individual fields. Transformer.__call__ passes a unified MultimodalInput object, but NNXDecoder previously accepted only the individual fields. |
| src/maxtext/utils/muon_utils.py | Two fixes: (1) in get_muon_weight_dimension_numbers, return the nnx.State directly, preserving .value attribute access for unit tests; (2) in get_model_mdn, normalize the NNX output via nnx.to_pure_dict plus a {"params": ...} wrapper so the test_model_integration expected values, which are written in Linen format, stay valid (sketch below). |
| src/maxtext/trainers/post_train/distillation/distillation_utils.py | Guard the optimizer_state restore on whether it exists in the checkpoint. PeftTrainer.save() only saves model_params, so restoring optimizer_state unconditionally caused a KeyError (sketch below). |
| tests/integration/gradient_accumulation_test.py | Switch test_sft_grad_accumulate_same_loss from the deprecated SFT loop to the NNX-native train_sft.py. The deprecated loop always passed nextrng as a third positional argument, which mismatched the 2-element NNX in_shardings. |
| tests/integration/decode_tests.py, generate_param_only_checkpoint_test.py, smoke/inference_microbenchmark_smoke_test.py | Add pure_nnx=False enable_nnx=False pure_nnx_decoder=False to all inference test configs. maxengine (the inference engine) does not yet support NNX, so decode tests must explicitly select the Linen path. |
| tests/unit/sharding_compare_test.py | Filter the abstract_state.model leaves to floating-point dtypes before asserting dtype == float32. The NNX model state includes RNG state variables (uint32/key) that are not weight parameters (sketch below). |
| src/maxtext/layers/nnx_decoders.py (scan update) | Replace self.layers = nnx.merge(...) with nnx.update(self.layers, nnx.state(...)) in _apply_layers_sequentially callers. Reassigning self.layers inside nnx.value_and_grad mutates the NNX graph structure and triggers ValueError: cached_partial graph structure mutated (sketch below). |
| src/maxtext/utils/gradient_accumulation.py | Fix ZeRO-1 + gradient accumulation (shard_mode=explicit). jax.lax.scan traces its body with an AbstractMesh whose axis types are all Auto, which rejects reduced/unreduced PartitionSpecs in scan carry tensors. Fix: use plain params_shardings in the scan carry, then apply the unreduced annotation to the gradients after the scan to trigger the all-reduce across data-parallel devices. Also adds copy=True to nnx.merge inside the scan body to avoid a TraceContextError from reused Variable objects. |
| src/maxtext/layers/pipeline.py | Fix NNX pipeline parallelism (test_full_train_non_circular). nnx.vmap's extract.to_tree checks Variable._can_update (via _trace_state.is_valid()). Variables created by nnx.merge inside jax.value_and_grad have _trace_state at the grad trace level; when nnx.vmap enters a deeper trace level, _can_update returns False and raises ValueError: Cannot extract graph node from different trace level. Fix: convert the state with nnx.to_pure_dict(state) before the nnx.vmap call and pass the pure dict into the vmapped function. A pure dict of arrays contains no Variable objects, so extract.to_tree skips the trace check, and nnx.merge(graph, pure_dict) inside the vmap body creates fresh Variables that are valid at the current trace level (sketch below). |
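
For the muon_utils change, a minimal sketch of the Linen-style wrapping, using a toy nnx.Linear in place of the real model; only the to_pure_dict plus {"params": ...} pattern mirrors the actual fix:

```python
from flax import nnx

# Toy stand-in for the NNX model handled by the muon utilities.
model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

# Extract parameters only, strip the Variable wrappers, and nest the result
# under "params" so expected values written against the Linen layout stay valid.
params_state = nnx.state(model, nnx.Param)
linen_like = {"params": nnx.to_pure_dict(params_state)}
```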
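
The distillation restore guard is simple; a toy sketch with an illustrative dict standing in for the real checkpoint structure:

```python
# PeftTrainer.save() writes only model_params, so a restored checkpoint may or
# may not contain optimizer_state; guard the lookup instead of indexing blindly.
restored = {"model_params": {"w": 1.0}}  # illustrative params-only checkpoint

model_params = restored["model_params"]
optimizer_state = restored.get("optimizer_state")  # None when absent
if optimizer_state is not None:
  # only then restore the optimizer from the checkpointed state
  ...
```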
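
For the sharding test, a sketch of the floating-point filter over the abstract model state, with jax.ShapeDtypeStruct leaves as illustrative stand-ins for abstract_state.model:

```python
import jax
import jax.numpy as jnp

# Illustrative abstract model state: a weight plus NNX RNG state (uint32 count).
abstract_model = {
  "decoder": {"kernel": jax.ShapeDtypeStruct((8, 8), jnp.float32)},
  "rngs": {"count": jax.ShapeDtypeStruct((), jnp.uint32)},
}

# Only weight-like leaves should be dtype-checked; RNG counters/keys are not
# parameters and would fail an unconditional float32 assertion.
float_leaves = [
  leaf
  for leaf in jax.tree_util.tree_leaves(abstract_model)
  if jnp.issubdtype(leaf.dtype, jnp.floating)
]
assert all(leaf.dtype == jnp.float32 for leaf in float_leaves)
```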
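
For the scan-update change in nnx_decoders.py, a reduced sketch; Block and ToyDecoder are toy modules, and only the update-in-place pattern mirrors the actual fix:

```python
from flax import nnx
import jax.numpy as jnp


class Block(nnx.Module):  # toy stand-in for the scanned decoder layers
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)


class ToyDecoder(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.layers = Block(rngs)

  def __call__(self, x):
    graphdef, state = nnx.split(self.layers)
    # ... apply the layers sequentially here, producing an updated `state` ...
    # Reassigning `self.layers = nnx.merge(graphdef, state)` at this point would
    # mutate the graph structure under nnx.value_and_grad and raise
    # "cached_partial graph structure mutated"; updating in place does not.
    nnx.update(self.layers, state)
    return self.layers(x)


decoder = ToyDecoder(nnx.Rngs(0))
out = decoder(jnp.ones((1, 4)))
```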
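
For the pipeline fix, a schematic sketch of the pure-dict handoff to nnx.vmap inside a grad trace; Stage is a toy module and the loss is arbitrary, so this only illustrates the pattern described above, not the actual pipeline code:

```python
import jax.numpy as jnp
from flax import nnx


class Stage(nnx.Module):  # toy stand-in for one pipeline stage
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)


def loss_fn(stage, x):
  graphdef, state = nnx.split(stage)
  # Variables in `state` were created at the grad trace level; handing them to
  # nnx.vmap would trip the trace-level check. A pure dict of arrays does not.
  pure_state = nnx.to_pure_dict(state)

  def run_stage(pure, batch):
    # Merging the graphdef with the pure dict creates fresh Variables that are
    # valid at the current (vmap) trace level.
    module = nnx.merge(graphdef, pure)
    return module(batch)

  out = nnx.vmap(run_stage, in_axes=(None, 0))(pure_state, x)
  return jnp.mean(out ** 2)


stage = Stage(nnx.Rngs(0))
x = jnp.ones((2, 3, 4))  # (microbatches, batch, features)
loss, grads = nnx.value_and_grad(loss_fn)(stage, x)
```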

Tests

python3 -m pytest tests/unit/ tests/integration/ -v

Checklist

Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files).

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 5 times, most recently from bac289f to db75887 on April 6, 2026 at 21:09
@ecnal-cienet ecnal-cienet changed the title from Feat/nnx set defaults true to NNX migration prep (5/N): enable NNX by default on Apr 6, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 14 times, most recently from ab2019b to ffeefdd on April 9, 2026 at 23:29
xibinliu and others added 4 commits April 9, 2026 23:36
- Add utils to manipulate NNX shardings using the abstract state of a model
  - also add unit tests for the utils
- Extract mesh creation function to maxtext_utils.get_mesh_from_config()
  - also add unit tests for this func

Note:
flax v0.12 has DeprecationWarning in multiple places:
  - DeprecationWarning: '.value' access is now deprecated. Use
    variable.get_value() or variable[...] (for [Array]).
  - DeprecationWarning: 'VariableState' was removed, this is just
    an alias to 'Variable'. Please use 'Variable' directly instead.
But since the code needs to work with post-training, which currently
requires flax v0.11, we didn't change code for these warnings.
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…ison utility

- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 2 times, most recently from f1e9765 to 5a7f63b on April 9, 2026 at 23:45
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch from 5a7f63b to 73213e0 on April 9, 2026 at 23:47
@ecnal-cienet ecnal-cienet changed the title from NNX migration prep (5/N): enable NNX by default to NNX migration prep (6/N): enable NNX by default on Apr 16, 2026
@ecnal-cienet ecnal-cienet changed the title from NNX migration prep (6/N): enable NNX by default to NNX migration prep (5/N): enable NNX by default on Apr 20, 2026
