
[NNX] NNX migration prep (4/N): sharding tools, Linen↔NNX checkpoint utilities, and post-training fixes #3652

Draft

ecnal-cienet wants to merge 3 commits into main from feat/nnx-post-train-fixes

Conversation

Collaborator

@ecnal-cienet ecnal-cienet commented Apr 13, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. 🔄 [This PR] NNX sharding diagnostics, bidirectional Linen↔NNX checkpoint conversion utilities, and post-training fixes.
  5. ❌ Enable NNX by default; fix unit and integration test failures.
  6. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

This PR combines two closely related changes — NNX tooling additions and post-training bugfixes — into a single reviewable unit.

Part 1: Sharding diagnostics and Linen↔NNX checkpoint utilities

| File | Change |
| --- | --- |
| `src/maxtext/utils/maxtext_utils.py` | Extend `print_shardings_params` to support NNX state trees in addition to Linen param dicts. |
| `src/maxtext/tools/run_sharding_dump.py` | Add a `--pure_nnx` flag to select the NNX sharding path. |
| `src/maxtext/utils/linen_nnx_converter.py` (new) | Bidirectional Linen↔NNX checkpoint conversion utility: convert a Linen checkpoint to NNX format and vice versa. |
| `src/maxtext/tools/compare_linen_nnx_checkpoint.py` (new) | Checkpoint comparison tool for validating numerical equivalence between Linen and NNX checkpoints. |
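The comparison tool's core job is walking two nested parameter trees and flagging values that are not numerically equivalent. A minimal stdlib sketch of that recursion (function and key names here are illustrative, not the tool's actual API):

```python
import math

def compare_trees(linen_tree, nnx_tree, rtol=1e-6, path=""):
  """Recursively compare two nested dicts of float lists; return mismatched paths."""
  mismatches = []
  if set(linen_tree) != set(nnx_tree):
    mismatches.append(f"{path}: key sets differ")
    return mismatches
  for key, left in linen_tree.items():
    right = nnx_tree[key]
    sub = f"{path}/{key}"
    if isinstance(left, dict):
      mismatches.extend(compare_trees(left, right, rtol, sub))
    elif any(not math.isclose(a, b, rel_tol=rtol) for a, b in zip(left, right)):
      mismatches.append(f"{sub}: values differ")
  return mismatches

# Identical trees produce no mismatches.
linen = {"decoder": {"kernel": [0.5, 1.25], "bias": [0.0]}}
nnx = {"decoder": {"kernel": [0.5, 1.25], "bias": [0.0]}}
assert compare_trees(linen, nnx) == []
```

The real tool compares Orbax-loaded arrays rather than Python lists, but the traversal and tolerance-based equality check follow the same shape.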

Part 2: Post-training bug fixes

| File | Fix |
| --- | --- |
| `src/maxtext/models/models.py` | `Transformer.__call__` was passing `multimodal_input=MultimodalInput(...)` to `NNXDecoder`, which only accepts individual fields (`image_embeddings`, `image_masks`, `audio_embeddings`, `audio_masks`, `bidirectional_mask`). Fixed by unpacking the object at the call site. |
| `src/maxtext/optimizers/optimizers.py` | `adam_pax` called `learning_rate_fn(count)` unconditionally, failing when `optax.inject_hyperparams` (used by the distillation trainer) passes a pre-evaluated scalar instead of a callable schedule. Fixed with a `callable()` guard. |
| `src/maxtext/trainers/post_train/sft/train_sft.py` | Tunix's default `PeftTrainer._train_step` nests `nnx.value_and_grad` inside `nnx.jit`, causing Flax NNX to assign conflicting `outer_index` values and raise `ValueError: The graph structure of a node added to cached_partial was mutated`. Fixed by adding a `MaxTextPeftTrainer` subclass that overrides `create_train_step_fn` to use `jax.value_and_grad` with an explicit `nnx.split`/`nnx.merge` pattern (matching MaxText's pre-training NNX train step). |
| `src/maxtext/trainers/post_train/distillation/train_distill.py` | Same nested NNX transform issue in `MaxTextDistillationTrainer._train_step`. Additionally, teacher inference now runs outside `value_and_grad` (the teacher is frozen via `stop_gradient` anyway) to avoid tracing it unnecessarily. Fixed with the same `jax.value_and_grad` plus explicit split/merge pattern. |
| `src/maxtext/trainers/post_train/rl/train_rl.py` | Two RL-specific fixes: (1) JAX 0.9+ `with_sharding_constraint` asserts instead of resharding under Explicit mesh axes; patched with a `try`/`except AssertionError` fallback to `jax.sharding.reshard`. (2) `tpu_inference` initializes weights as float32 by default; during weight sync, `tunix._apply_dtype_cast` was upcasting incoming bfloat16 MaxText weights to float32, causing a dtype mismatch in the ragged paged attention kernel; patched to skip bfloat16→float32 upcasts so synced weights stay bfloat16. Also passes `dtype` explicitly to vLLM init. |
| `src/maxtext/utils/model_creation_utils.py` | Guard against `None` checkpoint metadata in `create_nnx_model`, with a descriptive error message when the checkpoint directory is empty or the save did not complete. |
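The `models.py` fix above amounts to unpacking a container object into the individual keyword arguments the callee declares. A stdlib sketch of the before/after (the dataclass and decoder function here are toy stand-ins mirroring the fields listed in the description, not the real MaxText classes):

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class MultimodalInput:
  # Field names follow the ones listed in the PR description.
  image_embeddings: Optional[list] = None
  image_masks: Optional[list] = None
  audio_embeddings: Optional[list] = None
  audio_masks: Optional[list] = None
  bidirectional_mask: Optional[list] = None

def nnx_decoder(*, image_embeddings=None, image_masks=None,
                audio_embeddings=None, audio_masks=None,
                bidirectional_mask=None):
  """Stand-in for a decoder call that accepts individual fields only."""
  return {"image_embeddings": image_embeddings,
          "bidirectional_mask": bidirectional_mask}

mm = MultimodalInput(image_embeddings=[1, 2], bidirectional_mask=[0, 1])
# Before the fix: nnx_decoder(multimodal_input=mm) raises TypeError
# (unexpected keyword argument). After the fix: unpack at the call site.
out = nnx_decoder(**asdict(mm))
assert out["image_embeddings"] == [1, 2]
```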
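The `optimizers.py` guard is a one-branch change: only invoke the learning rate when it is actually a schedule. A minimal sketch (the wrapper function name is illustrative, not the actual `adam_pax` signature):

```python
def resolve_learning_rate(learning_rate_fn, count):
  """Guarded learning-rate lookup. optax.inject_hyperparams may substitute a
  pre-evaluated scalar for the schedule, so only call it when callable.
  Sketch of the guard described in the PR; the name is illustrative."""
  if callable(learning_rate_fn):
    return learning_rate_fn(count)
  return learning_rate_fn  # already a scalar hyperparameter

schedule = lambda step: 1e-3 * 0.5 ** step
assert resolve_learning_rate(schedule, 1) == 5e-4   # callable schedule: evaluated
assert resolve_learning_rate(2e-4, 1) == 2e-4       # injected scalar: passed through
```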
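RL fix (1) is a control-flow patch: attempt the sharding constraint, and fall back to an explicit reshard if it asserts. The sketch below shows only that control flow with toy stand-ins; the two callables model `jax.lax.with_sharding_constraint` and `jax.sharding.reshard`, and the dict-based "array" is purely illustrative:

```python
def constrain_or_reshard(x, sharding, with_sharding_constraint, reshard):
  """Try the constraint first; under JAX 0.9+ Explicit mesh axes it asserts
  instead of resharding, so fall back to an explicit reshard."""
  try:
    return with_sharding_constraint(x, sharding)
  except AssertionError:
    return reshard(x, sharding)

def strict_constraint(x, sharding):
  # Models the new behavior: assert rather than silently reshard.
  assert x["sharding"] == sharding, "mismatched sharding under Explicit axes"
  return x

def explicit_reshard(x, sharding):
  return {**x, "sharding": sharding}

arr = {"data": [1, 2], "sharding": "replicated"}
out = constrain_or_reshard(arr, "data_parallel", strict_constraint, explicit_reshard)
assert out["sharding"] == "data_parallel"  # fallback path taken
```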
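RL fix (2) skips exactly one cast direction during weight sync. A stdlib sketch of that filter, modeling dtypes as strings (this is a stand-in for the patched behavior, not `tunix._apply_dtype_cast` itself):

```python
def apply_dtype_cast(weights, target_dtype):
  """Sketch of the patched cast: keep incoming bfloat16 weights as bfloat16
  instead of upcasting to the float32 default, so the attention kernel sees
  a consistent dtype. Weights are (values, dtype) pairs; dtypes are strings."""
  cast = {}
  for name, (values, dtype) in weights.items():
    if dtype == "bfloat16" and target_dtype == "float32":
      cast[name] = (values, dtype)  # skip the upcast: stay bfloat16
    else:
      cast[name] = (values, target_dtype)
  return cast

synced = apply_dtype_cast(
    {"wq": ([0.5], "bfloat16"), "scale": ([1.0], "float16")}, "float32")
assert synced["wq"][1] == "bfloat16"     # bfloat16 weight left alone
assert synced["scale"][1] == "float32"   # other dtypes still cast as before
```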

Tests

  • Sharding dump validated on V6e-8 with --pure_nnx flag.
  • Linen↔NNX checkpoint conversion validated via compare_linen_nnx_checkpoint.py on gemma2-2b.
  • Post-training fixes validated during SFT and distillation trial runs on V6e-8.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet force-pushed the feat/nnx-post-train-fixes branch from 5430e95 to 86e20ea on April 13, 2026 15:02
@ecnal-cienet force-pushed the feat/nnx-post-train-fixes branch 5 times, most recently from a48bdd5 to 8239dd2 on April 16, 2026 17:47
@ecnal-cienet changed the title from "Feat/nnx post train fixes" to "[NNX] Post-training fixes: unpack MultimodalInput, scalar LR guard, and nested NNX transform workaround" on Apr 16, 2026
@ecnal-cienet changed the title from "[NNX] Post-training fixes: unpack MultimodalInput, scalar LR guard, and nested NNX transform workaround" to "NNX migration prep (5/N): Post-training fixes: unpack MultimodalInput, scalar LR guard, and nested NNX transform workaround" on Apr 16, 2026
@ecnal-cienet force-pushed the feat/nnx-post-train-fixes branch 3 times, most recently from d8cde29 to 05aede5 on April 16, 2026 22:25
@ecnal-cienet changed the title from "NNX migration prep (5/N): Post-training fixes: unpack MultimodalInput, scalar LR guard, and nested NNX transform workaround" to "NNX migration prep (5/N): Post-training fixes" on Apr 16, 2026
@ecnal-cienet force-pushed the feat/nnx-post-train-fixes branch 2 times, most recently from 2fbc8c2 to d88224e on April 20, 2026 14:00
@ecnal-cienet changed the title from "NNX migration prep (5/N): Post-training fixes" to "NNX] NNX migration prep (4/N): sharding tools, Linen↔NNX checkpoint utilities, and post-training fixes" on Apr 20, 2026
@ecnal-cienet force-pushed the feat/nnx-post-train-fixes branch 2 times, most recently from af8d71f to 6b6e61b on April 20, 2026 14:17
@ecnal-cienet changed the title from "NNX] NNX migration prep (4/N): sharding tools, Linen↔NNX checkpoint utilities, and post-training fixes" to "[NNX] NNX migration prep (4/N): sharding tools, Linen↔NNX checkpoint utilities, and post-training fixes" on Apr 20, 2026
@ecnal-cienet force-pushed the feat/nnx-post-train-fixes branch 5 times, most recently from 5f067bb to 7f06c99 on April 21, 2026 23:57
The NNX-migrated code lives in src/maxtext/ (lower-case) but
.coveragerc only listed the upper-case MaxText package in
[run] source. coverage.py therefore never instrumented the
new files, so added lines showed up on Codecov with no
coverage data and suppressed patch coverage.

The existing [paths] aliasing only merges already-collected
data across filesystems; it does not control what is traced
in the first place.
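The commit message above describes the fix in prose. A plausible `.coveragerc` sketch of the resulting shape (the exact entries are assumptions from the description, not the repository's actual file) would list both package roots under `[run] source` while keeping the `[paths]` aliasing:

```ini
[run]
# Trace both package roots. The [paths] section below only merges
# already-collected data; it does not control what gets instrumented.
source =
    MaxText
    src/maxtext

[paths]
source =
    src/maxtext
    */src/maxtext
```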
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…raining fixes

Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)

Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the
  whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
  (callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit
  raises conflicting outer_index error); refactored to jax.value_and_grad + explicit
  nnx.split/merge pattern; teacher inference moved outside value_and_grad
@ecnal-cienet force-pushed the feat/nnx-post-train-fixes branch from 7f06c99 to 28b1e4a on April 23, 2026 18:39