Commit edf5d3f
NNX: QK-Clip + NNX-format checkpoint utilities + GRPO bug fixes from PR8
Migrates the remaining Linen-only utilities to NNX (QK-Clip, three
checkpoint helpers) and picks up two real NNX-GRPO bugs flagged on the
already-merged PR8 review. Every NNX edit is gated on
`config.pure_nnx` or runtime state-shape detection; Linen paths are
preserved byte-for-byte.
QK-Clip:
- `qk_clip_utils.apply_qk_clip_nnx` mutates `state.model` in place via
`nnx.split` -> pure-dict `tree_map` -> `nnx.replace_by_pure_dict` ->
`nnx.update`. Accepts both the production NNX intermediate shape
(`self_attention.attention_op.max_logits`, sown inside `AttentionOp`)
and the synthetic-fixture shape used by the existing Linen tests
(`self_attention.max_logits`).
- `train.py::train_step` dispatches on `isinstance(model, nn.Module)`
to call `apply_qk_clip` (Linen) or `apply_qk_clip_nnx` (NNX). The
TODO at the QK-Clip call site is removed.
NNX-format checkpoint utilities:
- `standalone_checkpointer.checkpoint_loop` builds an NNX
`init_state_fn` under `pure_nnx` (mirroring PR8's GRPO trainer).
`add_entropy_to_checkpoint` dispatches across Linen `TrainState`,
NNX `TrainStateNNX` Module, and post-split `nnx.State` shapes; all
three produce identical `cos(1000*p)`/`sin(1000*p)` mu/nu
replacements.
- `generate_param_only_checkpoint`: `_read_train_checkpoint` builds an
NNX `init_state_fn` under `pure_nnx`. New `_possibly_unroll_params_nnx`
slices scanned NNX layers via dict-style mutation on
`state.model.decoder`. New `_save_decode_checkpoint_nnx` writes a
bf16 pure-dict tree to orbax. Parallel LoRA decode flow operates on
the single-nested LoRA delta tree from PR8's
`get_lora_abstract_state_nnx`.
- `convert_gpt3_ckpt_from_paxml`: parallel NNX `state_map` keystr
translation. Paths use the dict-style format that
`jax.tree_util.keystr` actually produces on an `nnx.State` — e.g.
`['optimizer']['step'].value` and
`['optimizer']['opt_state'][0]['mu']<rest>.value` (dict-style for
the State Mappings, `[0]` for the optax tuple, `.value` for the
`nnx.Variable` leaf). Save uses `state.optimizer.step.value` for
the step number on NNX. End-to-end paxml -> NNX conversion is wired
but not yet validated on hardware.
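The path format above can be reproduced with plain containers. The sketch below assumes a hypothetical nested dict (`state_like`) standing in for an `nnx.State`, with a one-element tuple standing in for the optax state tuple; it shows only the dict-style and `[0]` segments that `jax.tree_util.keystr` emits. The trailing `.value` segment comes from the `nnx.Variable` leaf and is not reproduced here.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the state layout; plain dicts and a tuple only.
state_like = {
    "model": {"decoder": {"kernel": jnp.ones((2, 2))}},
    "optimizer": {
        "step": jnp.array(0),
        "opt_state": (
            {
                "mu": {"decoder": {"kernel": jnp.zeros((2, 2))}},
                "nu": {"decoder": {"kernel": jnp.zeros((2, 2))}},
            },
        ),
    },
}

# Flatten with key paths and render each path as a string.
leaves_with_path, _ = jax.tree_util.tree_flatten_with_path(state_like)
paths = [jax.tree_util.keystr(path) for path, _ in leaves_with_path]

for p in paths:
  print(p)
```

Dict keys render as `['name']` and tuple positions as `[0]`, so a `state_map` keyed on these strings has to match that exact spelling rather than a dotted attribute form.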
NNX-GRPO bug fixes from merged PR8:
- `setup_train_loop`'s NNX `init_state_fn` called
`TrainStateNNX(nnx_model, optimizer, reference_model=...)` — but
`TrainStateNNX.__init__` only accepts `(model, optimizer)`, so this
would raise `TypeError` the first time GRPO ran with
`pure_nnx=True`. (The original code masked the failure with
`# pylint: disable-next=unexpected-keyword-arg` rather than running
it.) Fixed by constructing `TrainStateNNX` with the two valid args
and setting `state.reference_model` as a sibling attribute after
construction. `nnx.Module` is mutable, and the attribute survives
`nnx.split` / `nnx.merge` round-trips.
- `_train_step_nnx`'s `diff_wrapper` closed over `state.reference_model`
directly inside `jax.value_and_grad`. `nnx.Module` is not a
registered JAX pytree, so closure-capture only worked as long as
JAX treated the module as static — fragile, and any internal-state
touch during the reference forward would trace badly. Fixed by
mirroring the existing policy-model pattern:
`nnx.split(state.reference_model)` outside the wrapper, pass
`ref_state` as an explicit pytree argument into `diff_wrapper`, and
`nnx.merge` it inside.
Tests:
- `qk_clip_test`: 7 new NNX cases (`QKClipNNXTest`,
`CalculateMaxLogitNNXTest`) covering attention-type guard, MLA
wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold,
missing-stats resilience, Linen<->NNX numeric parity.
- `standalone_checkpointer_nnx_test` (new): 3 cases for adam mu/nu
overwrite on `TrainStateNNX` Module shape, no mutation of
`state.model` params, post-split `nnx.State` shape from
`setup_training_state`.
- `generate_param_only_checkpoint_nnx_test` (new): 3 cases for
scanned-layer slicing (Llama-style; DeepSeek-style dense+moe split;
LoRA delta unroll on the single-nested NNX shape).
Test results: 31 passed, 2 skipped across the PR9 surface
(`qk_clip_test`, `standalone_checkpointer_nnx_test`,
`generate_param_only_checkpoint_nnx_test`, `grpo_nnx_test`).
1 parent 56db4d5 commit edf5d3f
9 files changed
Lines changed: 926 additions & 159 deletions
File tree
- src/maxtext
- checkpoint_conversion/standalone_scripts
- experimental/rl
- trainers/pre_train
- utils
- tests/unit
Lines changed: 64 additions & 26 deletions