
NNX migration prep (4/N): sharding tools and Linen<->NNX checkpoint utilities#3525

Closed
ecnal-cienet wants to merge 2 commits into main from feat/nnx-linen-converter-and-sharding-tools
Conversation

ecnal-cienet (Collaborator) commented Mar 31, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. [This PR] NNX sharding diagnostics and bidirectional Linen↔NNX checkpoint conversion utilities. (PR #3525)
  5. ❌ NNX post-training fixes: MultimodalInput unpacking, scalar LR guard, nested NNX transform workaround.
  6. ❌ Enable NNX by default; fix unit and integration test failures.
  7. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

Note: This is the fourth in a series of NNX migration PRs. This PR adds developer tooling to inspect NNX sharding and convert / compare checkpoints across Linen and NNX formats. No training logic is changed.

Sharding diagnostics

  • maxtext_utils.py: print_shardings_params now dispatches on pure_nnx: for NNX models it iterates over the flat nnx.State rather than the Linen params tree.
  • tests/utils/run_sharding_dump.py: run_single_dump() now propagates --pure_nnx=true to the sharding-dump subprocess when the flag is set, enabling NNX sharding dumps without manual flag threading.

Linen ↔ NNX checkpoint converter

src/maxtext/checkpoint_conversion/linen_nnx_converter.py — a standalone CPU-only script that bidirectionally converts Orbax checkpoints between Linen and NNX formats.

Key transformations handled:

| Direction | params tree | opt_state | step | Layer layout |
| --- | --- | --- | --- | --- |
| Linen → NNX | `params/params/<model>` → `model/<model>` + `{value:}` wrappers | remove `params` level from `mu`/`nu` | move inside `optimizer/` | stack `layers_N` arrays → `layers` tensor (axis 1) |
| NNX → Linen | reverse of above | add `params` level | move to top level | unstack `layers` tensor → `layers_N` per-layer arrays |

--direction accepts linen_to_nnx, nnx_to_linen, or auto (detects format from checkpoint keys).
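The layer-layout transform in the table above can be sketched with NumPy. This is an illustrative stand-in, not the converter itself: the real script walks a full Orbax checkpoint tree, while this shows only the per-leaf array transform (stacking scanned layers on axis 1, per the table):

```python
import numpy as np

def stack_layers(per_layer: dict) -> np.ndarray:
  """Linen -> NNX direction: combine layers_0..layers_{N-1} arrays
  into a single stacked tensor on axis 1."""
  n = len(per_layer)
  ordered = [per_layer[f"layers_{i}"] for i in range(n)]
  return np.stack(ordered, axis=1)

def unstack_layers(stacked: np.ndarray) -> dict:
  """NNX -> Linen direction: split the stacked layers tensor
  back into per-layer layers_N arrays."""
  return {f"layers_{i}": np.take(stacked, i, axis=1) for i in range(stacked.shape[1])}
```

Round-tripping through these two functions recovers the original per-layer arrays, which is the property the comparison utility below relies on when normalizing cross-format checkpoints.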

Checkpoint comparison utility

src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py — compares tree structure, shapes, and optionally values between any two Orbax checkpoints (Linen vs NNX, or same-format). Auto-detects format and applies cross-format normalization (layer axis transposition, {value:} unwrapping, RNG filtering) only when needed.

# Structure + shape comparison (Linen vs NNX)
python compare_linen_nnx_checkpoint.py \
  --ckpt_path_1="gs://bucket/linen_checkpoint/0/items" \
  --ckpt_path_2="gs://bucket/nnx_checkpoint/0/items"

# Value comparison
python compare_linen_nnx_checkpoint.py \
  --ckpt_path_1="gs://bucket/ckpt_a/0/items" \
  --ckpt_path_2="gs://bucket/ckpt_b/0/items" \
  --compare_values --atol=1e-5 --rtol=1e-5
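The value-comparison step can be sketched as a tolerance check over two flat trees. This is a simplified stand-in for the actual utility (function name and the flat `{path: array}` input format are assumptions; the real script loads Orbax checkpoints and applies cross-format normalization first):

```python
import numpy as np

def compare_values(tree_a, tree_b, atol=1e-5, rtol=1e-5):
  """Return (path, reason) pairs for leaves that differ.

  Both inputs are flat {path: array} mappings, assumed already
  normalized to a common format (wrappers unwrapped, layer axes
  transposed, RNG keys filtered out).
  """
  mismatches = []
  for path in sorted(set(tree_a) | set(tree_b)):
    if path not in tree_a or path not in tree_b:
      mismatches.append((path, "missing"))  # structural mismatch
    elif not np.allclose(tree_a[path], tree_b[path], atol=atol, rtol=rtol):
      mismatches.append((path, "value"))    # numeric mismatch beyond tolerance
  return mismatches
```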

Tests

Unit tests:

python3 -m pytest tests/unit/linen_nnx_converter_test.py -v
python3 -m pytest tests/unit/compare_linen_nnx_checkpoint_test.py -v

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

ecnal-cienet force-pushed the feat/nnx-linen-converter-and-sharding-tools branch 6 times, most recently from bcd7b07 to f27e4f9 on April 6, 2026 at 19:02
ecnal-cienet force-pushed the feat/nnx-linen-converter-and-sharding-tools branch 5 times, most recently from 606baf8 to 9895925 on April 13, 2026 at 14:59
ecnal-cienet force-pushed the feat/nnx-linen-converter-and-sharding-tools branch 3 times, most recently from d6627ef to 91535ec on April 16, 2026 at 17:46
ecnal-cienet force-pushed the feat/nnx-linen-converter-and-sharding-tools branch from 91535ec to 3f34221 on April 16, 2026 at 22:23
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
ecnal-cienet force-pushed the feat/nnx-linen-converter-and-sharding-tools branch from 3f34221 to 56d4548 on April 20, 2026 at 13:52
…ison utility

- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)
ecnal-cienet force-pushed the feat/nnx-linen-converter-and-sharding-tools branch from 56d4548 to 07431d2 on April 20, 2026 at 13:52