Commit edf5d3f

NNX: QK-Clip + NNX-format checkpoint utilities + GRPO bug fixes from PR8
Migrates the remaining Linen-only utilities to NNX (QK-Clip, three checkpoint helpers) and picks up two real NNX-GRPO bugs flagged on the already-merged PR8 review. Every NNX edit is gated on `config.pure_nnx` or runtime state-shape detection; Linen paths are preserved byte-for-byte.

QK-Clip:
- `qk_clip_utils.apply_qk_clip_nnx` mutates `state.model` in place via `nnx.split` -> pure-dict `tree_map` -> `nnx.replace_by_pure_dict` -> `nnx.update`. Accepts both the production NNX intermediate shape (`self_attention.attention_op.max_logits`, sown inside `AttentionOp`) and the synthetic-fixture shape used by the existing Linen tests (`self_attention.max_logits`).
- `train.py::train_step` dispatches on `isinstance(model, nn.Module)` to call `apply_qk_clip` (Linen) or `apply_qk_clip_nnx` (NNX). The TODO at the QK-Clip call site is removed.

NNX-format checkpoint utilities:
- `standalone_checkpointer.checkpoint_loop` builds an NNX `init_state_fn` under `pure_nnx` (mirroring PR8's GRPO trainer). `add_entropy_to_checkpoint` dispatches across Linen `TrainState`, NNX `TrainStateNNX` Module, and post-split `nnx.State` shapes; all three produce identical `cos(1000*p)`/`sin(1000*p)` mu/nu replacements.
- `generate_param_only_checkpoint`: `_read_train_checkpoint` builds an NNX `init_state_fn` under `pure_nnx`. New `_possibly_unroll_params_nnx` slices scanned NNX layers via dict-style mutation on `state.model.decoder`. New `_save_decode_checkpoint_nnx` writes a bf16 pure-dict tree to orbax. Parallel LoRA decode flow operates on the single-nested LoRA delta tree from PR8's `get_lora_abstract_state_nnx`.
- `convert_gpt3_ckpt_from_paxml`: parallel NNX `state_map` keystr translation. Paths use the dict-style format that `jax.tree_util.keystr` actually produces on an `nnx.State`, e.g. `['optimizer']['step'].value` and `['optimizer']['opt_state'][0]['mu']<rest>.value` (dict-style for the State Mappings, `[0]` for the optax tuple, `.value` for the `nnx.Variable` leaf). Save uses `state.optimizer.step.value` for the step number on NNX. End-to-end paxml -> NNX conversion is wired but not yet validated on hardware.

NNX-GRPO bug fixes from merged PR8:
- `setup_train_loop`'s NNX `init_state_fn` called `TrainStateNNX(nnx_model, optimizer, reference_model=...)`, but `TrainStateNNX.__init__` only accepts `(model, optimizer)`, so this would raise `TypeError` the first time GRPO ran with `pure_nnx=True`. (The original code masked the failure with `# pylint: disable-next=unexpected-keyword-arg` rather than running it.) Fixed by constructing `TrainStateNNX` with the two valid args and setting `state.reference_model` as a sibling attribute after construction. `nnx.Module` is mutable, and the attribute survives `nnx.split` / `nnx.merge` round-trips.
- `_train_step_nnx`'s `diff_wrapper` closed over `state.reference_model` directly inside `jax.value_and_grad`. `nnx.Module` is not a registered JAX pytree, so closure-capture only worked as long as JAX treated the module as static: fragile, and any internal-state touch during the reference forward would trace badly. Fixed by mirroring the existing policy-model pattern: `nnx.split(state.reference_model)` outside the wrapper, pass `ref_state` as an explicit pytree argument into `diff_wrapper`, and `nnx.merge` it inside.

Tests:
- `qk_clip_test`: 7 new NNX cases (`QKClipNNXTest`, `CalculateMaxLogitNNXTest`) covering attention-type guard, MLA wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen<->NNX numeric parity.
- `standalone_checkpointer_nnx_test` (new): 3 cases for adam mu/nu overwrite on `TrainStateNNX` Module shape, no mutation of `state.model` params, post-split `nnx.State` shape from `setup_training_state`.
- `generate_param_only_checkpoint_nnx_test` (new): 3 cases for scanned-layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA delta unroll on the single-nested NNX shape).

Test results: 31 passed, 2 skipped across the PR9 surface (`qk_clip_test`, `standalone_checkpointer_nnx_test`, `generate_param_only_checkpoint_nnx_test`, `grpo_nnx_test`).
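
The `cos(1000*p)`/`sin(1000*p)` mu/nu replacement described for `add_entropy_to_checkpoint` can be pictured as two tree maps over the parameter tree. The following is a minimal sketch only: it assumes `p` denotes each parameter leaf, uses a plain dict in place of the real Linen or NNX state, and the helper name `entropy_mu_nu` is hypothetical rather than part of MaxText.

```python
import jax
import jax.numpy as jnp


def entropy_mu_nu(params):
  """Return (mu, nu) trees with the same structure as params (hypothetical helper)."""
  mu = jax.tree_util.tree_map(lambda p: jnp.cos(1000 * p), params)
  nu = jax.tree_util.tree_map(lambda p: jnp.sin(1000 * p), params)
  return mu, nu


# Toy parameter tree; a real checkpoint would hold sharded model weights.
params = {"decoder": {"kernel": jnp.zeros((2, 2)), "bias": jnp.zeros((2,))}}
mu, nu = entropy_mu_nu(params)
```

Because the replacement depends only on the parameter values, the same two maps would apply whichever container (Linen `TrainState`, `TrainStateNNX`, or a post-split `nnx.State`) holds the leaves, which is presumably why the three dispatch branches can be checked against each other.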
1 parent 56db4d5 commit edf5d3f

9 files changed

Lines changed: 926 additions & 159 deletions

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 64 additions & 26 deletions
@@ -40,6 +40,7 @@
 import os
 import sys
 
+from flax import nnx
 import jax
 from jax import random
 from jax.sharding import Mesh
@@ -48,11 +49,15 @@
 from maxtext.common import checkpointing
 from maxtext.common.common_types import MODEL_MODE_TRAIN
 from maxtext.layers import quantizations
+from maxtext.common import train_state_nnx
 from maxtext.models.models import transformer_as_linen
 from maxtext.optimizers import optimizers
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
+from maxtext.utils import maxtext_utils_nnx
+from maxtext.utils import model_creation_utils
+from maxtext.utils import train_utils
 import numpy as np
 from psutil import Process
 import tensorstore as ts
@@ -87,12 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
   devices_array = maxtext_utils.create_device_mesh(cfg)
   mesh = Mesh(devices_array, cfg.mesh_axes)
 
-  # Output is Linen-format (keystr_map below uses Linen tree paths). Route to
-  # Linen regardless of pure_nnx.
-  quant = quantizations.configure_quantization(cfg)
-  model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
-  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
-  tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
+  if cfg.pure_nnx:
+    rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
+    model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
+    _, tx = train_utils.create_training_optimizer(cfg, model)
+    _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
+
+    def init_state_fn():
+      nnx_model = _create_model_partial()
+      optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
+      return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
+
+  else:
+    quant = quantizations.configure_quantization(cfg)
+    model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
+    learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
+    tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
+    init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
 
   checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
       cfg.checkpoint_dir,
@@ -101,7 +117,6 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
       cfg.checkpoint_period,
   )
 
-  init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
   state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
   max_logging.log("start")
   max_utils.print_mem_stats("After params initialized")
@@ -186,10 +201,24 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
       "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
   }
 
-  state_map = {
-      ".step": ("step", None),
-      ".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
-  }
+  if cfg.pure_nnx:
+    # NNX state-tree paths after `nnx.split(TrainStateNNX)`. The state is a
+    # nested `nnx.State` (dict-like Mapping) with `nnx.Variable` leaves, so
+    # `jax.tree_util.keystr` produces dict-style entries (`['key']`) plus
+    # `.value` for the Variable leaf, plus `[idx]` for the optax tuple:
+    #   model params  -> ['model']<rest>.value
+    #   adam mu / nu  -> ['optimizer']['opt_state'][0]['mu' | 'nu']<rest>.value
+    #   step          -> ['optimizer']['step'].value
+    #   opt count     -> ['optimizer']['opt_state'][0]['count'].value
+    state_map = {
+        "['optimizer']['step'].value": ("step", None),
+        "['optimizer']['opt_state'][0]['count'].value": ("opt_states_0.no_prefix_0.count", None),
+    }
+  else:
+    state_map = {
+        ".step": ("step", None),
+        ".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
+    }
 
   def get_layer_prefix(keystr_pax):
     # different path format between decoder_layer variable
@@ -201,19 +230,27 @@ def get_layer_prefix(keystr_pax):
     return prefix_pax_opt_state
 
   for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
-    # model variable
-    state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
     prefix_pax_opt_state = get_layer_prefix(keystr_pax)
-    # first momentum in optimizer state
-    state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
-        f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
-        transform_fn,
-    )
-    # second momentum in optimizer state
-    state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
-        f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
-        transform_fn,
-    )
+    if cfg.pure_nnx:
+      state_map[f"['model']{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
+      state_map[f"['optimizer']['opt_state'][0]['mu']{keystr_maxtext}.value"] = (
+          f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
+          transform_fn,
+      )
+      state_map[f"['optimizer']['opt_state'][0]['nu']{keystr_maxtext}.value"] = (
+          f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
+          transform_fn,
+      )
+    else:
+      state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
+      state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
+          f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
+          transform_fn,
+      )
+      state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
+          f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
+          transform_fn,
+      )
 
   def verify_fn(key_path, _):
     keystr = jax.tree_util.keystr(key_path)
@@ -265,10 +302,11 @@ def map_fn(key_path, value):
   max_logging.log("converted state finished")
   max_utils.print_mem_stats("converted state finished")
 
-  if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
-    max_logging.log(f"saved a checkpoint at step {converted_state.step}")
+  step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
+  if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
+    max_logging.log(f"saved a checkpoint at step {step_value}")
   # Upon preemption, exit when and only when all ongoing saves are complete.
-  if checkpoint_manager.reached_preemption(converted_state.step):
+  if checkpoint_manager.reached_preemption(step_value):
     checkpoint_manager.wait_until_finished()
     sys.exit()
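
The dict-style key strings used in the NNX `state_map` above can be reproduced on a toy tree. This sketch assumes a plain nested dict plus a one-element tuple standing in for the optax `opt_state`; it is not a real `nnx.State`, so the trailing `.value` that the commit attributes to `nnx.Variable` leaves does not appear here. The tree contents and key names under `decoder` are illustrative only.

```python
import jax

# Toy stand-in for a split train state: nested dicts plus the optax tuple.
state = {
    "model": {"decoder": {"kernel": 1.0}},
    "optimizer": {
        "step": 0,
        "opt_state": (
            {
                "count": 0,
                "mu": {"decoder": {"kernel": 0.0}},
                "nu": {"decoder": {"kernel": 0.0}},
            },
        ),
    },
}

# Flatten with key paths, then render each path the way the converter does.
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(state)
paths = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]
for rendered in paths:
  print(rendered)
```

Dict keys render as `['key']` and tuple positions as `[0]`, which matches the `['optimizer']['opt_state'][0]['mu']...` prefixes in the diff.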

src/maxtext/experimental/rl/grpo_trainer.py

Lines changed: 16 additions & 6 deletions
@@ -486,15 +486,22 @@ def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data):
 
   state = nnx.merge(model_graphdef, state) # Reconstruct the TrainStateNNX.
   policy_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...)
-
-  def diff_wrapper(param, rest, config, data):
+  # Split the reference model into (graphdef, state) so we pass `ref_state` as
+  # an explicit pytree-typed argument to `diff_wrapper` instead of closing over
+  # the mutable nnx.Module — closure capture inside jax.value_and_grad works
+  # only by accident (Modules aren't registered JAX pytrees) and breaks the
+  # moment the reference forward touches any internal state.
+  ref_graphdef, ref_state = nnx.split(state.reference_model)
+
+  def diff_wrapper(param, rest, ref_state, config, data):
     local_model = nnx.merge(policy_graphdef, param, rest, copy=True)
-    loss, aux = grpo_loss_fn_nnx(local_model, config, data, None, None, state.reference_model, is_train=True)
+    local_ref = nnx.merge(ref_graphdef, ref_state, copy=True)
+    loss, aux = grpo_loss_fn_nnx(local_model, config, data, None, None, local_ref, is_train=True)
     _, _, new_rest = nnx.split(local_model, nnx.Param, ...)
     return loss, (aux, new_rest)
 
   grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True)
-  (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data)
+  (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, ref_state, config, data)
   nnx.update(state.model, new_rest)
 
   if config.gradient_clipping_threshold > 0:
@@ -798,8 +805,11 @@ def init_state_fn():
       optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
       # Reference uses the same init seed so it starts identical to the policy.
       reference_model = _create_model_partial()
-      # pylint: disable-next=unexpected-keyword-arg
-      return train_state_nnx.TrainStateNNX(nnx_model, optimizer, reference_model=reference_model)
+      # TrainStateNNX only takes (model, optimizer); reference_model is an NNX
+      # sibling attribute set after construction (nnx.Module is mutable).
+      state = train_state_nnx.TrainStateNNX(nnx_model, optimizer)
+      state.reference_model = reference_model
+      return state
 
   else:
     init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
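
The `diff_wrapper` change above amounts to passing the reference state as an explicit pytree argument and differentiating only with respect to the policy parameters. The sketch below illustrates that shape with plain dict pytrees instead of `nnx.split` / `nnx.merge`; the `forward` function and the squared-difference loss are hypothetical stand-ins, not the GRPO loss.

```python
import jax
import jax.numpy as jnp


def forward(params, x):
  """Tiny linear model standing in for both policy and reference."""
  return x @ params["w"] + params["b"]


def diff_wrapper(policy_params, ref_params, x):
  # ref_params arrives as an explicit pytree argument, not via closure.
  policy_out = forward(policy_params, x)
  ref_out = forward(ref_params, x)
  loss = jnp.mean((policy_out - ref_out) ** 2)
  return loss, {"ref_out": ref_out}


policy_params = {"w": jnp.array([[2.0], [0.0]]), "b": jnp.array([0.0])}
ref_params = {"w": jnp.array([[1.0], [0.0]]), "b": jnp.array([0.0])}
x = jnp.array([[1.0, 1.0]])

# argnums=0 restricts gradients to the policy; the reference is traced
# like any other input but receives no gradient.
grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True)
(loss, aux), grads = grad_func(policy_params, ref_params, x)
```

In the real trainer the two dicts would be the `nnx.State` objects produced by `nnx.split`, with `nnx.merge` rebuilding the modules inside the wrapper.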

src/maxtext/trainers/pre_train/train.py

Lines changed: 2 additions & 2 deletions
@@ -499,11 +499,11 @@ def move(path, value):
       "learning/total_weights": total_weights,
   }
   if config.use_qk_clip:
-    # Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX)
     if isinstance(model, nn.Module):
       new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
+    else:
+      new_state = qk_clip_utils.apply_qk_clip_nnx(new_state, intermediate_outputs, config)
 
-    # Report max_logits metric
     global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
     if global_max_logit is not None:
       scalar_metrics["learning/max_logits"] = global_max_logit
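
For readers unfamiliar with what the QK-Clip call does to the weights, here is a rough sketch of the per-head rescaling as QK-Clip is commonly described: heads whose max attention logit exceeds a threshold have their query and key projections scaled down so the logit lands back at the threshold. This is an assumption-laden illustration, not code from `qk_clip_utils`: the function names, the `(embed, heads, head_dim)` weight layout, and the even square-root split are all hypothetical, and the MLA path (wq_b / wkv_b) mentioned in the commit message is not covered.

```python
import jax.numpy as jnp


def qk_clip_scale(max_logits, threshold):
  """Per-head factor: 1.0 at or below the threshold, threshold / max_logit above it."""
  return jnp.minimum(1.0, threshold / max_logits)


def clip_qk(w_q, w_k, max_logits, threshold):
  # Split the factor evenly so the q.k logit shrinks by exactly `scale`.
  root = jnp.sqrt(qk_clip_scale(max_logits, threshold))
  # Weights are assumed shaped (embed, heads, head_dim); broadcast over heads.
  return w_q * root[None, :, None], w_k * root[None, :, None]


w_q = jnp.ones((4, 2, 3))
w_k = jnp.ones((4, 2, 3))
max_logits = jnp.array([50.0, 200.0])  # head 0 is under the threshold
new_q, new_k = clip_qk(w_q, w_k, max_logits, threshold=100.0)
```

Head 0 is left untouched, which corresponds to the no-clip-below-threshold case listed among the new tests.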
