Skip to content

Commit 2a7775a

Browse files
committed
NNX: QK-Clip on NNX + NNX-format checkpoint utilities
Closes the QK-Clip TODO and migrates the remaining Linen-only checkpoint utilities to NNX. Linen paths are preserved byte-for-byte (every NNX edit is gated on `config.pure_nnx` or runtime state-shape detection).

QK-Clip:
- qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via nnx.split -> pure-dict tree_map -> nnx.replace_by_pure_dict -> nnx.update. Accepts both the production NNX intermediate shape (self_attention.attention_op.max_logits) and the synthetic-fixture shape from the existing Linen tests (self_attention.max_logits).
- train.py train_step dispatches to apply_qk_clip_nnx for NNX, removing the prior TODO at the QK-Clip call site.

Checkpoint utilities (NNX paths added):
- standalone_checkpointer.checkpoint_loop builds an NNX init_state_fn under pure_nnx; add_entropy_to_checkpoint dispatches across Linen TrainState, NNX TrainStateNNX Module, and post-split nnx.State shapes.
- generate_param_only_checkpoint: NNX init_state_fn under pure_nnx; _possibly_unroll_params_nnx slices scanned NNX layers via dict-style state mutation; _save_decode_checkpoint_nnx writes a bf16 pure-dict tree to orbax. Parallel LoRA decode flow operates on the single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx.
- convert_gpt3_ckpt_from_paxml: parallel NNX state_map keystr translation (.params['params']<rest> -> .model<rest>.value, etc.). End-to-end paxml -> NNX conversion is wired but not yet validated on hardware.

Tests:
- qk_clip_test: 7 new NNX cases covering attention-type guard, MLA wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen<->NNX numeric parity.
- standalone_checkpointer_nnx_test (new): 3 cases for adam mu/nu overwrite on TrainStateNNX Module shape, no mutation of state.model params, post-split nnx.State shape from setup_training_state.
- generate_param_only_checkpoint_nnx_test (new): 3 cases for scanned layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA delta unroll on the single-nested NNX shape).

NNX + AQT in MaxEngine and the layerwise_quantization NNX path are split into the follow-up PR9.5.
1 parent 78049f9 commit 2a7775a

8 files changed

Lines changed: 907 additions & 161 deletions

File tree

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import os
4141
import sys
4242

43+
from flax import nnx
4344
import jax
4445
from jax import random
4546
from jax.sharding import Mesh
@@ -48,11 +49,15 @@
4849
from maxtext.common import checkpointing
4950
from maxtext.common.common_types import MODEL_MODE_TRAIN
5051
from maxtext.layers import quantizations
52+
from maxtext.layers import train_state_nnx
5153
from maxtext.models.models import transformer_as_linen
5254
from maxtext.optimizers import optimizers
5355
from maxtext.utils import max_logging
5456
from maxtext.utils import max_utils
5557
from maxtext.utils import maxtext_utils
58+
from maxtext.utils import maxtext_utils_nnx
59+
from maxtext.utils import model_creation_utils
60+
from maxtext.utils import train_utils
5661
import numpy as np
5762
from psutil import Process
5863
import tensorstore as ts
@@ -87,14 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
8792
devices_array = maxtext_utils.create_device_mesh(cfg)
8893
mesh = Mesh(devices_array, cfg.mesh_axes)
8994

90-
# This conversion script reads paxml-format weights and emits a Linen-format
91-
# MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`,
92-
# `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen
93-
# tree shape). Use the Linen path regardless of pure_nnx.
94-
quant = quantizations.configure_quantization(cfg)
95-
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
96-
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
97-
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
95+
if cfg.pure_nnx:
96+
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
97+
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
98+
_, tx = train_utils.create_training_optimizer(cfg, model)
99+
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
100+
101+
def init_state_fn():
102+
nnx_model = _create_model_partial()
103+
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
104+
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
105+
106+
else:
107+
quant = quantizations.configure_quantization(cfg)
108+
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
109+
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
110+
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
111+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
98112

99113
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
100114
cfg.checkpoint_dir,
@@ -103,7 +117,6 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
103117
cfg.checkpoint_period,
104118
)
105119

106-
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
107120
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
108121
max_logging.log("start")
109122
max_utils.print_mem_stats("After params initialized")
@@ -188,10 +201,21 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
188201
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
189202
}
190203

191-
state_map = {
192-
".step": ("step", None),
193-
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
194-
}
204+
if cfg.pure_nnx:
205+
# NNX state-tree paths after `nnx.split(TrainStateNNX)`:
206+
# model params -> ['model']<rest>.value
207+
# adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
208+
# step -> ['optimizer']['step'].value
209+
# opt count -> ['optimizer']['opt_state']['count'].value
210+
state_map = {
211+
".optimizer.step.value": ("step", None),
212+
".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
213+
}
214+
else:
215+
state_map = {
216+
".step": ("step", None),
217+
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
218+
}
195219

196220
def get_layer_prefix(keystr_pax):
197221
# different path format between decoder_layer variable
@@ -203,19 +227,27 @@ def get_layer_prefix(keystr_pax):
203227
return prefix_pax_opt_state
204228

205229
for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
206-
# model variable
207-
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
208230
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
209-
# first momentum in optimizer state
210-
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
211-
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
212-
transform_fn,
213-
)
214-
# second momentum in optimizer state
215-
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
216-
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
217-
transform_fn,
218-
)
231+
if cfg.pure_nnx:
232+
state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
233+
state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
234+
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
235+
transform_fn,
236+
)
237+
state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
238+
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
239+
transform_fn,
240+
)
241+
else:
242+
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
243+
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
244+
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
245+
transform_fn,
246+
)
247+
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
248+
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
249+
transform_fn,
250+
)
219251

220252
def verify_fn(key_path, _):
221253
keystr = jax.tree_util.keystr(key_path)
@@ -267,10 +299,11 @@ def map_fn(key_path, value):
267299
max_logging.log("converted state finished")
268300
max_utils.print_mem_stats("converted state finished")
269301

270-
if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
271-
max_logging.log(f"saved a checkpoint at step {converted_state.step}")
302+
step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
303+
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
304+
max_logging.log(f"saved a checkpoint at step {step_value}")
272305
# Upon preemption, exit when and only when all ongoing saves are complete.
273-
if checkpoint_manager.reached_preemption(converted_state.step):
306+
if checkpoint_manager.reached_preemption(step_value):
274307
checkpoint_manager.wait_until_finished()
275308
sys.exit()
276309

src/maxtext/trainers/pre_train/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,11 +530,11 @@ def move(path, value):
530530
"learning/total_weights": total_weights,
531531
}
532532
if config.use_qk_clip:
533-
# Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX)
534533
if isinstance(model, nn.Module):
535534
new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
535+
else:
536+
new_state = qk_clip_utils.apply_qk_clip_nnx(new_state, intermediate_outputs, config)
536537

537-
# Report max_logits metric
538538
global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
539539
if global_max_logit is not None:
540540
scalar_metrics["learning/max_logits"] = global_max_logit

0 commit comments

Comments
 (0)