Skip to content

Commit a4b9db9

Browse files
committed
NNX: QK-Clip on NNX + NNX-format checkpoint utilities
Closes the QK-Clip TODO and migrates the remaining Linen-only checkpoint utilities to NNX. Linen paths preserved byte-for-byte (every NNX edit is gated on `config.pure_nnx` or runtime state-shape detection). QK-Clip: - qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via nnx.split -> pure-dict tree_map -> nnx.replace_by_pure_dict -> nnx.update. Accepts both the production NNX intermediate shape (self_attention.attention_op.max_logits) and the synthetic-fixture shape from the existing Linen tests (self_attention.max_logits). - train.py train_step dispatches to apply_qk_clip_nnx for NNX, removing the prior TODO at the QK-Clip call site. Checkpoint utilities (NNX paths added): - standalone_checkpointer.checkpoint_loop builds an NNX init_state_fn under pure_nnx; add_entropy_to_checkpoint dispatches across Linen TrainState, NNX TrainStateNNX Module, and post-split nnx.State shapes. - generate_param_only_checkpoint: NNX init_state_fn under pure_nnx; _possibly_unroll_params_nnx slices scanned NNX layers via dict-style state mutation; _save_decode_checkpoint_nnx writes a bf16 pure-dict tree to orbax. Parallel LoRA decode flow operates on the single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx. - convert_gpt3_ckpt_from_paxml: parallel NNX state_map keystr translation (.params['params']<rest> -> .model<rest>.value, etc.). End-to-end paxml -> NNX conversion is wired but not yet validated on hardware. Tests: - qk_clip_test: 7 new NNX cases covering attention-type guard, MLA wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen<->NNX numeric parity. - standalone_checkpointer_nnx_test (new): 3 cases for adam mu/nu overwrite on TrainStateNNX Module shape, no mutation of state.model params, post-split nnx.State shape from setup_training_state. - generate_param_only_checkpoint_nnx_test (new): 3 cases for scanned layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA delta unroll on the single-nested NNX shape). NNX + AQT in MaxEngine and the layerwise_quantization NNX path are split into the follow-up PR9.5.
1 parent c62cbbb commit a4b9db9

8 files changed

Lines changed: 907 additions & 153 deletions

File tree

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import os
4141
import sys
4242

43+
from flax import nnx
4344
import jax
4445
from jax import random
4546
from jax.sharding import Mesh
@@ -48,11 +49,15 @@
4849
from maxtext.common import checkpointing
4950
from maxtext.common.common_types import MODEL_MODE_TRAIN
5051
from maxtext.layers import quantizations
52+
from maxtext.layers import train_state_nnx
5153
from maxtext.models.models import transformer_as_linen
5254
from maxtext.optimizers import optimizers
5355
from maxtext.utils import max_logging
5456
from maxtext.utils import max_utils
5557
from maxtext.utils import maxtext_utils
58+
from maxtext.utils import maxtext_utils_nnx
59+
from maxtext.utils import model_creation_utils
60+
from maxtext.utils import train_utils
5661
import numpy as np
5762
from psutil import Process
5863
import tensorstore as ts
@@ -87,12 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
8792
devices_array = maxtext_utils.create_device_mesh(cfg)
8893
mesh = Mesh(devices_array, cfg.mesh_axes)
8994

90-
# Output is Linen-format (keystr_map below uses Linen tree paths). Route to
91-
# Linen regardless of pure_nnx.
92-
quant = quantizations.configure_quantization(cfg)
93-
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
94-
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
95-
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
95+
if cfg.pure_nnx:
96+
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
97+
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
98+
_, tx = train_utils.create_training_optimizer(cfg, model)
99+
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
100+
101+
def init_state_fn():
102+
nnx_model = _create_model_partial()
103+
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
104+
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
105+
106+
else:
107+
quant = quantizations.configure_quantization(cfg)
108+
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
109+
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
110+
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
111+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
96112

97113
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
98114
cfg.checkpoint_dir,
@@ -101,7 +117,6 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
101117
cfg.checkpoint_period,
102118
)
103119

104-
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
105120
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
106121
max_logging.log("start")
107122
max_utils.print_mem_stats("After params initialized")
@@ -186,10 +201,21 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
186201
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
187202
}
188203

189-
state_map = {
190-
".step": ("step", None),
191-
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
192-
}
204+
if cfg.pure_nnx:
205+
# NNX state-tree paths after `nnx.split(TrainStateNNX)`:
206+
# model params -> ['model']<rest>.value
207+
# adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
208+
# step -> ['optimizer']['step'].value
209+
# opt count -> ['optimizer']['opt_state']['count'].value
210+
state_map = {
211+
".optimizer.step.value": ("step", None),
212+
".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
213+
}
214+
else:
215+
state_map = {
216+
".step": ("step", None),
217+
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
218+
}
193219

194220
def get_layer_prefix(keystr_pax):
195221
# different path format between decoder_layer variable
@@ -201,19 +227,27 @@ def get_layer_prefix(keystr_pax):
201227
return prefix_pax_opt_state
202228

203229
for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
204-
# model variable
205-
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
206230
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
207-
# first momentum in optimizer state
208-
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
209-
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
210-
transform_fn,
211-
)
212-
# second momentum in optimizer state
213-
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
214-
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
215-
transform_fn,
216-
)
231+
if cfg.pure_nnx:
232+
state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
233+
state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
234+
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
235+
transform_fn,
236+
)
237+
state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
238+
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
239+
transform_fn,
240+
)
241+
else:
242+
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
243+
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
244+
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
245+
transform_fn,
246+
)
247+
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
248+
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
249+
transform_fn,
250+
)
217251

218252
def verify_fn(key_path, _):
219253
keystr = jax.tree_util.keystr(key_path)
@@ -265,10 +299,11 @@ def map_fn(key_path, value):
265299
max_logging.log("converted state finished")
266300
max_utils.print_mem_stats("converted state finished")
267301

268-
if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
269-
max_logging.log(f"saved a checkpoint at step {converted_state.step}")
302+
step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
303+
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
304+
max_logging.log(f"saved a checkpoint at step {step_value}")
270305
# Upon preemption, exit when and only when all ongoing saves are complete.
271-
if checkpoint_manager.reached_preemption(converted_state.step):
306+
if checkpoint_manager.reached_preemption(step_value):
272307
checkpoint_manager.wait_until_finished()
273308
sys.exit()
274309

src/maxtext/trainers/pre_train/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,11 +526,11 @@ def move(path, value):
526526
"learning/total_weights": total_weights,
527527
}
528528
if config.use_qk_clip:
529-
# Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX)
530529
if isinstance(model, nn.Module):
531530
new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
531+
else:
532+
new_state = qk_clip_utils.apply_qk_clip_nnx(new_state, intermediate_outputs, config)
532533

533-
# Report max_logits metric
534534
global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
535535
if global_max_logit is not None:
536536
scalar_metrics["learning/max_logits"] = global_max_logit

0 commit comments

Comments
 (0)