
[NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities#3836

Merged
copybara-service[bot] merged 1 commit into main from feat/nnx-qk-clip-and-checkpoint-utils
Jun 9, 2026

@ecnal-cienet ecnal-cienet commented May 7, 2026

Collaborator

NNX Migration Route Map

  1. ✅ pure_nnx flag + init_state_fn scaffolding (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ TrainStateNNX + end-to-end training loop (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics + post-training bugfixes (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint conversion utilities (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
  5. ✅ NNX correctness fixes + vocab tiling on NNX (PR [NNX] NNX migration prep (5/N): correctness fixes and feature enablements #3766)
  6. ✅ NNX-native DPO
  7. ✅ NNX-native MaxEngine inference (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. ✅ NNX-native LoRA + GRPO (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. 🔄 [This PR] NNX-aware QK-Clip + remaining NNX-format checkpoint utilities
    9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix (stacked follow-up)
  10. ❌ Vocab tiling custom_vjp for NNX
  11. ❌ Set NNX defaults to True
  12. ❌ Delete Linen-specific code paths

Description

Closes the QK-Clip TODO and migrates three Linen-only checkpoint helpers (standalone_checkpointer, generate_param_only_checkpoint, convert_gpt3_ckpt_from_paxml) to NNX. Every NNX edit is gated on config.pure_nnx; Linen paths are preserved byte-for-byte.

  • qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via nnx.split → pure-dict tree_map → nnx.replace_by_pure_dict → nnx.update.
  • train.py::train_step dispatches to apply_qk_clip_nnx on NNX. TODO at the call site is gone.
  • standalone_checkpointer.add_entropy_to_checkpoint dispatches across Linen TrainState, NNX TrainStateNNX, and post-split nnx.State.
  • generate_param_only_checkpoint gets NNX _possibly_unroll_params_nnx + _save_decode_checkpoint_nnx, plus parallel LoRA decode flow.
  • convert_gpt3_ckpt_from_paxml keystr_map uses the dict-style format jax.tree_util.keystr actually produces on nnx.State: ['optimizer']['step'].value, ['optimizer']['opt_state'][0]['mu']<rest>.value (dict for State, [0] for the optax tuple, .value for the Variable leaf).
  • PR #3824 ([NNX] NNX migration prep (8/N): NNX native lora grpo) was merged before its bot review could be addressed. This PR also fixes the two bugs that review flagged; both are real and only hit the NNX GRPO path:
    1. setup_train_loop called TrainStateNNX(model, optimizer, reference_model=...) — but __init__ only accepts (model, optimizer). Fixed by setting state.reference_model as a sibling attribute after construction.
    2. _train_step_nnx::diff_wrapper closed over state.reference_model (an nnx.Module) inside jax.value_and_grad. Fixed by nnx.split-ing the reference outside and passing ref_state as an explicit arg, matching the existing policy-model pattern in the same function.
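
The first fix above can be illustrated with a toy stand-in. `ToyTrainState` below is hypothetical and only mirrors the two-argument constructor described in item 1; it is not the real `TrainStateNNX`, and the strings stand in for real modules and optimizers.

```python
class ToyTrainState:
  """Hypothetical stand-in: accepts only (model, optimizer), as described above."""

  def __init__(self, model, optimizer):
    self.model = model
    self.optimizer = optimizer


def build_state_buggy(model, optimizer, reference_model):
  # Mirrors the original call shape: an unsupported keyword argument.
  return ToyTrainState(model, optimizer, reference_model=reference_model)


def build_state_fixed(model, optimizer, reference_model):
  # Construct with the two valid args, then attach the reference afterwards
  # as a sibling attribute.
  state = ToyTrainState(model, optimizer)
  state.reference_model = reference_model
  return state


try:
  build_state_buggy("policy", "adamw", "reference")
  buggy_error = None
except TypeError as err:
  buggy_error = err

state = build_state_fixed("policy", "adamw", "reference")
```

The buggy call fails only when it is actually executed, which is consistent with the failure staying hidden until GRPO runs on the NNX path.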

Tests

  • qk_clip_test: 7 new NNX cases (attention-type guard, MLA wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen↔NNX numeric parity).
  • standalone_checkpointer_nnx_test (new): 3 cases.
  • generate_param_only_checkpoint_nnx_test (new): 3 cases (Llama-style, DeepSeek-style, LoRA delta unroll).
  • Existing Linen tests untouched; pure_nnx=False stays default.

Results: 31 passed, 2 skipped (qk_clip_test + 2 new files + grpo_nnx_test).

Stats

+999 / −232 across 9 files (2 new, 7 modified).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented May 7, 2026


@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch 2 times, most recently from 68eb7ce to 02ff5f7 Compare May 7, 2026 20:12
@ecnal-cienet ecnal-cienet changed the title from "[NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities + NNX-AQT" to "[NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities" May 7, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch from 02ff5f7 to b5fd654 Compare May 7, 2026 21:53
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch 12 times, most recently from 2a7775a to 6748af8 Compare May 14, 2026 22:51
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch 9 times, most recently from 99b7f9d to ee99e98 Compare May 22, 2026 21:10
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch from ee99e98 to a4b9db9 Compare May 25, 2026 15:26
@github-actions

github-actions Bot commented Jun 2, 2026


🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

github-actions Bot commented Jun 2, 2026


🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request successfully migrates QK-Clip and key checkpointing utilities to the NNX-based training and inference paths. The implementation maintains compatibility with existing Linen workflows and adds comprehensive unit tests for the new NNX-aware logic. However, there is a critical runtime error in the GRPO trainer initialization and significant mapping issues in the PaxML checkpoint conversion script that must be addressed.

🔍 General Feedback

  • Functional Correctness: A critical bug was found in grpo_trainer.py where TrainStateNNX is called with an unsupported keyword argument.
  • NNX Best Practices: The use of nnx.split, nnx.merge, and nnx.update for in-place mutation of model states is well-implemented across qk_clip_utils.py and lora_utils.py.
  • Checkpointing Compatibility: The addition of _save_decode_checkpoint_nnx ensures that checkpoints generated from NNX runs are in the correct format for downstream inference consumers.
  • Testing: Excellent addition of NNX-specific unit tests, particularly for the complex scanned layer unrolling and LoRA delta application logic.

@github-actions github-actions Bot left a comment


## 📋 Review Summary

Additional inline feedback focusing on functional correctness in the GRPO trainer and state mapping in the checkpoint conversion script.

Comment thread src/maxtext/experimental/rl/grpo_trainer.py Outdated
Comment thread src/maxtext/experimental/rl/grpo_trainer.py Outdated
Comment thread src/maxtext/utils/standalone_checkpointer.py
Migrates the remaining Linen-only utilities to NNX (QK-Clip, three
checkpoint helpers) and picks up two real NNX-GRPO bugs flagged on the
already-merged PR8 review. Every NNX edit is gated on
`config.pure_nnx` or runtime state-shape detection; Linen paths are
preserved byte-for-byte.

QK-Clip:
- `qk_clip_utils.apply_qk_clip_nnx` mutates `state.model` in place via
  `nnx.split` -> pure-dict `tree_map` -> `nnx.replace_by_pure_dict` ->
  `nnx.update`. Accepts both the production NNX intermediate shape
  (`self_attention.attention_op.max_logits`, sown inside `AttentionOp`)
  and the synthetic-fixture shape used by the existing Linen tests
  (`self_attention.max_logits`).
- `train.py::train_step` dispatches on `isinstance(model, nn.Module)`
  to call `apply_qk_clip` (Linen) or `apply_qk_clip_nnx` (NNX). The
  TODO at the QK-Clip call site is removed.
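
A minimal sketch of the "pure-dict `tree_map`" step, written against plain nested dicts rather than real `nnx.State` objects. It assumes the common QK-Clip rule for a plain attention head, `gamma = min(1, threshold / max_logit)` with `sqrt(gamma)` applied to each of the query and key weights. That rule, the single scalar `max_logit`, and the `query`/`key`/`value` leaf names are assumptions for illustration; the MLA `wq_b`/`wkv_b` handling in `apply_qk_clip_nnx` is not reproduced here.

```python
import math


def clip_scale(max_logit, threshold):
  """Return gamma = min(1, threshold / max_logit)."""
  if max_logit <= threshold:
    return 1.0
  return threshold / max_logit


def map_with_path(fn, tree, path=()):
  """Apply fn(path, leaf) to every non-dict leaf of a nested dict."""
  if isinstance(tree, dict):
    return {k: map_with_path(fn, v, path + (k,)) for k, v in tree.items()}
  return fn(path, tree)


def apply_qk_clip_toy(params, max_logit, threshold):
  """Return a new tree with query/key weights rescaled by sqrt(gamma)."""
  root = math.sqrt(clip_scale(max_logit, threshold))

  def rescale(path, leaf):
    # Only query and key weights are touched; every other leaf passes through.
    if path[-1] in ("query", "key"):
      return [w * root for w in leaf]
    return leaf

  return map_with_path(rescale, params)


params = {"self_attention": {"query": [2.0, -4.0], "key": [1.0, 3.0], "value": [5.0]}}
clipped = apply_qk_clip_toy(params, max_logit=400.0, threshold=100.0)
unchanged = apply_qk_clip_toy(params, max_logit=50.0, threshold=100.0)
```

With a max logit four times the threshold, `gamma` is 0.25, so query and key are each halved and their product shrinks by 0.25. Below the threshold the tree comes back unchanged, which matches the no-clip-below-threshold test case.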

NNX-format checkpoint utilities:
- `standalone_checkpointer.checkpoint_loop` builds an NNX
  `init_state_fn` under `pure_nnx` (mirroring PR8's GRPO trainer).
  `add_entropy_to_checkpoint` dispatches across Linen `TrainState`,
  NNX `TrainStateNNX` Module, and post-split `nnx.State` shapes; all
  three produce identical `cos(1000*p)`/`sin(1000*p)` mu/nu
  replacements.
- `generate_param_only_checkpoint`: `_read_train_checkpoint` builds an
  NNX `init_state_fn` under `pure_nnx`. New `_possibly_unroll_params_nnx`
  slices scanned NNX layers via dict-style mutation on
  `state.model.decoder`. New `_save_decode_checkpoint_nnx` writes a
  bf16 pure-dict tree to orbax. Parallel LoRA decode flow operates on
  the single-nested LoRA delta tree from PR8's
  `get_lora_abstract_state_nnx`.
- `convert_gpt3_ckpt_from_paxml`: parallel NNX `state_map` keystr
  translation. Paths use the dict-style format that
  `jax.tree_util.keystr` actually produces on an `nnx.State` — e.g.
  `['optimizer']['step'].value` and
  `['optimizer']['opt_state'][0]['mu']<rest>.value` (dict-style for
  the State Mappings, `[0]` for the optax tuple, `.value` for the
  `nnx.Variable` leaf). Save uses `state.optimizer.step.value` for
  the step number on NNX. End-to-end paxml -> NNX conversion is wired
  but not yet validated on hardware.
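
The path format described for `convert_gpt3_ckpt_from_paxml` can be mimicked in plain Python. The sketch below does not call `jax.tree_util.keystr`; it only reproduces the three conventions named above (dict-style keys, `[0]` for a tuple index, `.value` for a Variable-like leaf). `ToyVariable` and the `decoder`/`kernel` names are hypothetical.

```python
class ToyVariable:
  """Hypothetical stand-in for a Variable leaf that exposes `.value`."""

  def __init__(self, value):
    self.value = value


def key_paths(tree, prefix=""):
  """Yield (path_string, leaf_value) pairs in the dict-style format."""
  if isinstance(tree, dict):
    for key, child in tree.items():
      yield from key_paths(child, f"{prefix}[{key!r}]")
  elif isinstance(tree, (list, tuple)):
    for index, child in enumerate(tree):
      yield from key_paths(child, f"{prefix}[{index}]")
  elif isinstance(tree, ToyVariable):
    # The leaf's payload is reached through attribute access.
    yield f"{prefix}.value", tree.value
  else:
    yield prefix, tree


state = {
    "optimizer": {
        "step": ToyVariable(0),
        # A one-element tuple stands in for the optax state tuple.
        "opt_state": ({"mu": {"decoder": {"kernel": ToyVariable(0.1)}}},),
    }
}
paths = dict(key_paths(state))
```

Under these assumptions, `paths` contains `['optimizer']['step'].value` and `['optimizer']['opt_state'][0]['mu']['decoder']['kernel'].value`, the same shape as the keys the `state_map` translation has to match.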

NNX-GRPO bug fixes from merged PR8:
- `setup_train_loop`'s NNX `init_state_fn` called
  `TrainStateNNX(nnx_model, optimizer, reference_model=...)` — but
  `TrainStateNNX.__init__` only accepts `(model, optimizer)`, so this
  would raise `TypeError` the first time GRPO ran with
  `pure_nnx=True`. (The original code masked the failure with
  `# pylint: disable-next=unexpected-keyword-arg` rather than running
  it.) Fixed by constructing `TrainStateNNX` with the two valid args
  and setting `state.reference_model` as a sibling attribute after
  construction. `nnx.Module` is mutable, and the attribute survives
  `nnx.split` / `nnx.merge` round-trips.
- `_train_step_nnx`'s `diff_wrapper` closed over `state.reference_model`
  directly inside `jax.value_and_grad`. `nnx.Module` is not a
  registered JAX pytree, so closure-capture only worked as long as
  JAX treated the module as static — fragile, and any internal-state
  touch during the reference forward would trace badly. Fixed by
  mirroring the existing policy-model pattern:
  `nnx.split(state.reference_model)` outside the wrapper, pass
  `ref_state` as an explicit pytree argument into `diff_wrapper`, and
  `nnx.merge` it inside.
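
The second fix follows a split-outside, merge-inside pattern. The toy below shows only the data flow: `toy_split` and `toy_merge` are hypothetical stand-ins for `nnx.split` and `nnx.merge`, no JAX tracing is involved, and `ToyLinear` is an illustrative module.

```python
class ToyLinear:
  """Hypothetical module holding a single weight."""

  def __init__(self, weight):
    self.weight = weight

  def __call__(self, x):
    return self.weight * x


def toy_split(module):
  """Stand-in for nnx.split: return (static definition, plain-data state)."""
  return type(module), dict(vars(module))


def toy_merge(module_cls, state):
  """Stand-in for nnx.merge: rebuild a module from definition and state."""
  module = module_cls.__new__(module_cls)
  vars(module).update(state)
  return module


def diff_wrapper(policy_state, ref_state, policy_def, ref_def, x):
  # Both models are rebuilt from explicit arguments; nothing is captured
  # from the enclosing scope.
  policy = toy_merge(policy_def, policy_state)
  reference = toy_merge(ref_def, ref_state)
  return (policy(x) - reference(x)) ** 2


policy_def, policy_state = toy_split(ToyLinear(3.0))
ref_def, ref_state = toy_split(ToyLinear(1.0))
loss = diff_wrapper(policy_state, ref_state, policy_def, ref_def, x=2.0)
```

Because the reference state enters as an argument, the wrapped function depends only on its inputs, which is the property the fix relies on when the real wrapper is differentiated.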

Tests:
- `qk_clip_test`: 7 new NNX cases (`QKClipNNXTest`,
  `CalculateMaxLogitNNXTest`) covering attention-type guard, MLA
  wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold,
  missing-stats resilience, Linen<->NNX numeric parity.
- `standalone_checkpointer_nnx_test` (new): 3 cases for adam mu/nu
  overwrite on `TrainStateNNX` Module shape, no mutation of
  `state.model` params, post-split `nnx.State` shape from
  `setup_training_state`.
- `generate_param_only_checkpoint_nnx_test` (new): 3 cases for
  scanned-layer slicing (Llama-style; DeepSeek-style dense+moe split;
  LoRA delta unroll on the single-nested NNX shape).
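
The mu/nu overwrite that the `standalone_checkpointer_nnx_test` cases check can be sketched on plain dicts. This assumes `mu` receives `cos(1000*p)` and `nu` receives `sin(1000*p)`, inferred from the ordering in the notes above; the tree layout and the `count` field are hypothetical.

```python
import math


def map_leaves(fn, tree):
  """Apply fn to every non-dict leaf of a nested dict."""
  if isinstance(tree, dict):
    return {k: map_leaves(fn, v) for k, v in tree.items()}
  return fn(tree)


def add_entropy(opt_state, params):
  """Return a copy of opt_state with mu/nu derived from the params."""
  new_state = dict(opt_state)
  new_state["mu"] = map_leaves(lambda p: math.cos(1000.0 * p), params)
  new_state["nu"] = map_leaves(lambda p: math.sin(1000.0 * p), params)
  return new_state


params = {"decoder": {"kernel": 0.0, "bias": 0.002}}
opt_state = {
    "count": 7,
    "mu": {"decoder": {"kernel": 0.0, "bias": 0.0}},
    "nu": {"decoder": {"kernel": 0.0, "bias": 0.0}},
}
new_opt_state = add_entropy(opt_state, params)
```

The params tree is only read, never written, which corresponds to the "no mutation of `state.model` params" case.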

Test results: 31 passed, 2 skipped across the PR9 surface
(`qk_clip_test`, `standalone_checkpointer_nnx_test`,
`generate_param_only_checkpoint_nnx_test`, `grpo_nnx_test`).

3 participants