[NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities#3836
This Pull Request successfully migrates QK-Clip and key checkpointing utilities to the NNX-based training and inference paths. The implementation maintains compatibility with existing Linen workflows and adds comprehensive unit tests for the new NNX-aware logic. However, there is a critical runtime error in the GRPO trainer initialization and significant mapping issues in the PaxML checkpoint conversion script that must be addressed.
🔍 General Feedback
- Functional Correctness: A critical bug was found in `grpo_trainer.py` where `TrainStateNNX` is called with an unsupported keyword argument.
- NNX Best Practices: The use of `nnx.split`, `nnx.merge`, and `nnx.update` for in-place mutation of model states is well-implemented across `qk_clip_utils.py` and `lora_utils.py`.
- Checkpointing Compatibility: The addition of `_save_decode_checkpoint_nnx` ensures that checkpoints generated from NNX runs are in the correct format for downstream inference consumers.
- Testing: Excellent addition of NNX-specific unit tests, particularly for the complex scanned layer unrolling and LoRA delta application logic.
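The keyword-argument bug flagged above can be reproduced without Flax. The following is a minimal sketch in plain Python; `TrainStateSketch` and the two builder functions are hypothetical stand-ins, not the real `TrainStateNNX` or trainer code. It only illustrates why a constructor that accepts `(model, optimizer)` raises `TypeError` on an extra keyword, and the shape of the fix described in the PR (construct first, attach the reference model afterwards).

```python
class TrainStateSketch:
    """Hypothetical stand-in for TrainStateNNX: accepts only (model, optimizer)."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer


def build_state_buggy(model, optimizer, reference_model):
    # Mirrors the reported bug: the constructor has no `reference_model` parameter,
    # so Python raises TypeError as soon as this line runs.
    return TrainStateSketch(model, optimizer, reference_model=reference_model)


def build_state_fixed(model, optimizer, reference_model):
    # Construct with the two accepted arguments, then attach the reference
    # model as a sibling attribute on the (mutable) state object.
    state = TrainStateSketch(model, optimizer)
    state.reference_model = reference_model
    return state
```

A lint suppression such as `# pylint: disable-next=unexpected-keyword-arg` silences the static warning but does not change this runtime behaviour.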
Migrates the remaining Linen-only utilities to NNX (QK-Clip, three checkpoint helpers) and picks up two real NNX-GRPO bugs flagged on the already-merged PR8 review. Every NNX edit is gated on `config.pure_nnx` or runtime state-shape detection; Linen paths are preserved byte-for-byte.

QK-Clip:
- `qk_clip_utils.apply_qk_clip_nnx` mutates `state.model` in place via `nnx.split` -> pure-dict `tree_map` -> `nnx.replace_by_pure_dict` -> `nnx.update`. Accepts both the production NNX intermediate shape (`self_attention.attention_op.max_logits`, sown inside `AttentionOp`) and the synthetic-fixture shape used by the existing Linen tests (`self_attention.max_logits`).
- `train.py::train_step` dispatches on `isinstance(model, nn.Module)` to call `apply_qk_clip` (Linen) or `apply_qk_clip_nnx` (NNX). The TODO at the QK-Clip call site is removed.

NNX-format checkpoint utilities:
- `standalone_checkpointer.checkpoint_loop` builds an NNX `init_state_fn` under `pure_nnx` (mirroring PR8's GRPO trainer). `add_entropy_to_checkpoint` dispatches across Linen `TrainState`, NNX `TrainStateNNX` Module, and post-split `nnx.State` shapes; all three produce identical `cos(1000*p)`/`sin(1000*p)` mu/nu replacements.
- `generate_param_only_checkpoint`: `_read_train_checkpoint` builds an NNX `init_state_fn` under `pure_nnx`. New `_possibly_unroll_params_nnx` slices scanned NNX layers via dict-style mutation on `state.model.decoder`. New `_save_decode_checkpoint_nnx` writes a bf16 pure-dict tree to orbax. Parallel LoRA decode flow operates on the single-nested LoRA delta tree from PR8's `get_lora_abstract_state_nnx`.
- `convert_gpt3_ckpt_from_paxml`: parallel NNX `state_map` keystr translation. Paths use the dict-style format that `jax.tree_util.keystr` actually produces on an `nnx.State`, e.g. `['optimizer']['step'].value` and `['optimizer']['opt_state'][0]['mu']<rest>.value` (dict-style for the State Mappings, `[0]` for the optax tuple, `.value` for the `nnx.Variable` leaf). Save uses `state.optimizer.step.value` for the step number on NNX. End-to-end paxml -> NNX conversion is wired but not yet validated on hardware.

NNX-GRPO bug fixes from merged PR8:
- `setup_train_loop`'s NNX `init_state_fn` called `TrainStateNNX(nnx_model, optimizer, reference_model=...)`, but `TrainStateNNX.__init__` only accepts `(model, optimizer)`, so this would raise `TypeError` the first time GRPO ran with `pure_nnx=True`. (The original code masked the failure with `# pylint: disable-next=unexpected-keyword-arg` rather than running it.) Fixed by constructing `TrainStateNNX` with the two valid args and setting `state.reference_model` as a sibling attribute after construction. `nnx.Module` is mutable, and the attribute survives `nnx.split` / `nnx.merge` round-trips.
- `_train_step_nnx`'s `diff_wrapper` closed over `state.reference_model` directly inside `jax.value_and_grad`. `nnx.Module` is not a registered JAX pytree, so closure-capture only worked as long as JAX treated the module as static, which is fragile: any internal-state touch during the reference forward would trace badly. Fixed by mirroring the existing policy-model pattern: `nnx.split(state.reference_model)` outside the wrapper, pass `ref_state` as an explicit pytree argument into `diff_wrapper`, and `nnx.merge` it inside.

Tests:
- `qk_clip_test`: 7 new NNX cases (`QKClipNNXTest`, `CalculateMaxLogitNNXTest`) covering attention-type guard, MLA wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen<->NNX numeric parity.
- `standalone_checkpointer_nnx_test` (new): 3 cases for adam mu/nu overwrite on `TrainStateNNX` Module shape, no mutation of `state.model` params, post-split `nnx.State` shape from `setup_training_state`.
- `generate_param_only_checkpoint_nnx_test` (new): 3 cases for scanned-layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA delta unroll on the single-nested NNX shape).

Test results: 31 passed, 2 skipped across the PR9 surface (`qk_clip_test`, `standalone_checkpointer_nnx_test`, `generate_param_only_checkpoint_nnx_test`, `grpo_nnx_test`).
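
The scanned-layer slicing exercised by `generate_param_only_checkpoint_nnx_test` can be pictured with a small dependency-free sketch. This is not the PR's `_possibly_unroll_params_nnx`: the function name, the `layers` / `layers_{i}` key naming, and the assumption that the layer index is the leading axis of every stacked leaf are all illustrative assumptions, and plain nested lists stand in for arrays.

```python
def unroll_scanned_layers(decoder, num_layers, stacked_key="layers"):
    """Replace decoder[stacked_key] with one per-layer subtree per index.

    Assumes (for illustration only) that every leaf under `stacked_key` is a
    list whose first axis indexes the layer.
    """
    stacked = decoder.pop(stacked_key)

    def take(tree, i):
        # Recurse through nested dicts; slice the leading axis at the leaves.
        if isinstance(tree, dict):
            return {name: take(sub, i) for name, sub in tree.items()}
        if len(tree) != num_layers:
            raise ValueError("leaf does not have num_layers entries on axis 0")
        return tree[i]

    # Dict-style mutation: write each slice back under its own key.
    for i in range(num_layers):
        decoder[f"{stacked_key}_{i}"] = take(stacked, i)
    return decoder
```

For a two-layer tree such as `{"layers": {"mlp": {"wi": [[1, 2], [3, 4]]}}}`, the call leaves `layers_0` holding `[1, 2]` and `layers_1` holding `[3, 4]`, with the stacked `layers` entry removed.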
NNX Migration Route Map
4.5. ✅ Linen↔NNX checkpoint conversion utilities (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix (stacked follow-up)
`custom_vjp` for NNX
`True`
Description
Closes the QK-Clip TODO and migrates three Linen-only checkpoint helpers (`standalone_checkpointer`, `generate_param_only_checkpoint`, `convert_gpt3_ckpt_from_paxml`) to NNX. Every NNX edit is gated on `config.pure_nnx`; Linen paths preserved byte-for-byte.

- `qk_clip_utils.apply_qk_clip_nnx` mutates `state.model` in place via `nnx.split` → pure-dict tree_map → `nnx.replace_by_pure_dict` → `nnx.update`.
- `train.py::train_step` dispatches to `apply_qk_clip_nnx` on NNX. TODO at the call site is gone.
- `standalone_checkpointer.add_entropy_to_checkpoint` dispatches across Linen `TrainState`, NNX `TrainStateNNX`, and post-split `nnx.State`.
- `generate_param_only_checkpoint` gets NNX `_possibly_unroll_params_nnx` + `_save_decode_checkpoint_nnx`, plus parallel LoRA decode flow.
- `convert_gpt3_ckpt_from_paxml` keystr_map uses the dict-style format `jax.tree_util.keystr` actually produces on `nnx.State`: `['optimizer']['step'].value`, `['optimizer']['opt_state'][0]['mu']<rest>.value` (dict for State, `[0]` for the optax tuple, `.value` for the Variable leaf).
- `setup_train_loop` called `TrainStateNNX(model, optimizer, reference_model=...)`, but `__init__` only accepts `(model, optimizer)`. Fixed by setting `state.reference_model` as a sibling attribute after construction.
- `_train_step_nnx::diff_wrapper` closed over `state.reference_model` (an `nnx.Module`) inside `jax.value_and_grad`. Fixed by `nnx.split`-ing the reference outside and passing `ref_state` as an explicit arg, matching the existing policy-model pattern in the same function.

Tests
- `qk_clip_test`: 7 new NNX cases (attention-type guard, MLA wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen↔NNX numeric parity).
- `standalone_checkpointer_nnx_test` (new): 3 cases.
- `generate_param_only_checkpoint_nnx_test` (new): 3 cases (Llama-style, DeepSeek-style, LoRA delta unroll).
- `pure_nnx=False` stays default.

Results: 31 passed, 2 skipped (`qk_clip_test` + 2 new files + `grpo_nnx_test`).

Stats
+999 / −232 across 9 files (2 new, 7 modified).
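
The PR text does not spell out the clipping rule itself, so the sketch below assumes the commonly described QK-Clip form for plain multi-head attention: when a head's observed max attention logit exceeds a threshold, the query and key projection weights for that head are each scaled by the square root of `threshold / max_logit`, so their product is scaled back to the threshold. The function names and the `threshold` parameter are hypothetical, and this does not reproduce the MLA `wq_b`/`wkv_b` handling that `apply_qk_clip_nnx` covers.

```python
import math


def qk_clip_scale(max_logit, threshold):
    """Per-head factor applied to both the query and the key weights."""
    if max_logit <= threshold:
        # No clipping below the threshold: weights are left untouched.
        return 1.0
    # Splitting the correction evenly gives sqrt(gamma) on each side,
    # where gamma = threshold / max_logit.
    return math.sqrt(threshold / max_logit)


def clip_head(wq, wk, max_logit, threshold):
    """Scale one head's flattened query/key weights (lists of floats)."""
    scale = qk_clip_scale(max_logit, threshold)
    return [w * scale for w in wq], [w * scale for w in wk]
```

Because both projections receive the same factor, a logit formed from their product shrinks by `threshold / max_logit`, which is what brings the observed maximum back to the threshold under this assumption.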