Skip to content

Commit 6b6e61b

Browse files
committed
NNX: add sharding tools, Linen<->NNX checkpoint utilities, and post-training fixes
Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)

Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (previously the whole object was passed as the multimodal_input= kwarg, but NNXDecoder only accepts the individual fields)
- optimizers.py: guard adam_pax against a scalar learning rate from optax.inject_hyperparams (check callable() before invoking learning_rate_fn)
- train_distill.py: fix a nested NNX transform issue (nnx.value_and_grad inside nnx.jit raises a conflicting outer_index error); refactored to jax.value_and_grad with an explicit nnx.split/merge pattern; teacher inference moved outside value_and_grad
1 parent fe7d3e5 commit 6b6e61b

13 files changed

Lines changed: 2823 additions & 81 deletions

File tree

src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py

Lines changed: 609 additions & 0 deletions
Large diffs are not rendered by default.

src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Lines changed: 581 additions & 0 deletions
Large diffs are not rendered by default.

src/maxtext/models/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,11 @@ def __call__(
520520
previous_chunk=previous_chunk,
521521
slot=slot,
522522
page_state=page_state,
523-
multimodal_input=multimodal_input,
523+
image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None,
524+
image_masks=multimodal_input.image_masks if multimodal_input is not None else None,
525+
audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None,
526+
audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None,
527+
bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None,
524528
kv_caches=kv_caches,
525529
attention_metadata=attention_metadata,
526530
deepstack_visual_embeds=deepstack_visual_embeds,

src/maxtext/optimizers/optimizers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu):
336336
else:
337337
updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)
338338

339-
step_size = -1.0 * learning_rate_fn(count)
339+
# learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped
340+
# by optax.inject_hyperparams, it is passed as a pre-evaluated scalar).
341+
step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn)
340342
# Finally, fold in step size.
341343
updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)
342344

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -267,30 +267,44 @@ def _safe_shard(x, pspec):
267267
nnx.update(self.optimizer, optimizer_sharded_state)
268268

269269
def _train_step(self, model, optimizer, inputs):
270-
"""Overrides the main JIT block to natively handle ModelBundle module."""
270+
"""Overrides the main JIT block to natively handle ModelBundle module.
271271
272+
Uses jax.value_and_grad with explicit split/merge to avoid nesting
273+
nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign
274+
conflicting outer_index values and raises:
275+
ValueError: The graph structure of a node added to cached_partial was
276+
mutated inside the transformation.
277+
"""
272278
batch = self.gen_model_input_fn(inputs)
273-
current_step = model.training_step.value
274-
275-
def loss_wrapper(student, teacher, batch):
276-
if "teacher_output" in batch:
277-
teacher_output = batch["teacher_output"]
278-
else:
279-
teacher_output = self.strategy.teacher_forward_fn(
280-
model=teacher,
281-
input_tokens=batch["input_tokens"],
282-
positions=batch["positions"],
283-
attention_mask=batch.get("attention_mask"),
284-
decoder_segment_ids=batch.get("decoder_segment_ids"),
285-
decoder_target_tokens=batch.get("targets", None),
286-
decoder_target_mask=batch.get("targets_segmentation", None),
287-
cache=None,
288-
)
279+
student = model.student_model
280+
teacher = model.teacher_model
281+
282+
# Run teacher inference outside of value_and_grad.
283+
# The teacher is frozen (stop_gradient), so its output is a constant
284+
# from the perspective of the student gradient computation.
285+
if "teacher_output" in batch:
286+
teacher_output = batch["teacher_output"]
287+
else:
288+
teacher_output = self.strategy.teacher_forward_fn(
289+
model=teacher,
290+
input_tokens=batch["input_tokens"],
291+
positions=batch["positions"],
292+
attention_mask=batch.get("attention_mask"),
293+
decoder_segment_ids=batch.get("decoder_segment_ids"),
294+
decoder_target_tokens=batch.get("targets", None),
295+
decoder_target_mask=batch.get("targets_segmentation", None),
296+
cache=None,
297+
)
298+
teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
289299

290-
teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
300+
# Split student into differentiable params and non-differentiable rest.
301+
# Capture graphdef outside of jax.value_and_grad for stable graph tracking.
302+
student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...)
291303

304+
def loss_wrapper_pure(diff_params, rest):
305+
local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True)
292306
student_output = self.strategy.student_forward_fn(
293-
model=student,
307+
model=local_student,
294308
input_tokens=batch["input_tokens"],
295309
positions=batch["positions"],
296310
attention_mask=batch.get("attention_mask"),
@@ -299,30 +313,27 @@ def loss_wrapper(student, teacher, batch):
299313
decoder_target_mask=batch.get("targets_segmentation", None),
300314
cache=None,
301315
)
302-
# we should apply a mask for labels to disable segment-separator tokens
303316
labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
304-
return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)
305-
306-
# Because student is the 0th argument, argnums=0 guarantees
307-
# we only compute gradients for the student.
308-
grad_fn = nnx.value_and_grad(
309-
loss_wrapper,
310-
argnums=nnx.DiffState(0, self.wrt_filter),
311-
has_aux=True,
312-
)
317+
loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels)
318+
# Capture updated non-param state (e.g. RNG counters) from local_student.
319+
_, _, new_rest = nnx.split(local_student, self.wrt_filter, ...)
320+
return loss, (aux, new_rest)
313321

314-
out, grads = grad_fn(model.student_model, model.teacher_model, batch)
322+
grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True)
323+
(loss, (aux, new_rest)), grads = grad_fn(diff_params, rest)
324+
325+
# Propagate updated non-param state back to student.
326+
nnx.update(student, new_rest)
327+
328+
optimizer.update(student, grads)
315329

316330
# Increment step counter after loss computation
317331
model.training_step.value = current_step + 1
318332

319333
tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
320-
321-
optimizer.update(model.student_model, grads)
322-
323334
if tunix_expects_grad_norm:
324-
return out[0], out[1], optax.global_norm(grads)
325-
return out[0], out[1]
335+
return loss, aux, optax.global_norm(grads)
336+
return loss, aux
326337

327338
def _eval_step(self, model, inputs):
328339
"""Evaluation only needs the student."""

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,42 @@
5656
import pathwaysutils
5757
import tensorflow_datasets as tfds
5858

59+
# JAX 0.9+ changed with_sharding_constraint to assert (not reshard) when all
60+
# mesh axes are Explicit. tpu_inference still expects resharding semantics.
61+
# Patch: try the original (works for Auto axes); on AssertionError (Explicit
62+
# mesh) fall back to jax.sharding.reshard.
63+
_orig_wsc = jax.lax.with_sharding_constraint
64+
65+
66+
def _compat_wsc(x, shardings):
67+
try:
68+
return _orig_wsc(x, shardings)
69+
except AssertionError:
70+
return jax.sharding.reshard(x, shardings)
71+
72+
73+
jax.lax.with_sharding_constraint = _compat_wsc
74+
75+
# tpu_inference JaxEinsum defaults param_dtype=float32, so tpu_inference model weights
76+
# initialize as float32. During weight sync, tunix._apply_dtype_cast then upcasts the
77+
# incoming bfloat16 MaxText weights → float32 to match the target. This leaves v_proj
78+
# as float32 while k_proj output appears bfloat16 (due to k_norm dtype promotion),
79+
# causing a dtype mismatch in the ragged paged attention kernel.
80+
# Fix: skip bfloat16→float32 upcasts during weight sync so synced weights stay bfloat16.
81+
import jax.numpy as _jnp
82+
import tunix.generate.utils as _tunix_utils
83+
84+
_orig_apply_dtype_cast = _tunix_utils._apply_dtype_cast # pylint: disable=protected-access
85+
86+
87+
def _no_bf16_to_f32_cast(val, tgt_dtype, src_key):
88+
if hasattr(val, "dtype") and val.dtype == _jnp.bfloat16 and tgt_dtype == _jnp.float32:
89+
return val # keep bfloat16; tpu_inference model dtype is bfloat16 despite float32 init
90+
return _orig_apply_dtype_cast(val, tgt_dtype, src_key)
91+
92+
93+
_tunix_utils._apply_dtype_cast = _no_bf16_to_f32_cast # pylint: disable=protected-access
94+
5995
from absl import app
6096
from absl import logging as absl_logging
6197
from etils import epath
@@ -543,6 +579,8 @@ def create_rl_components(
543579
"hf_overrides": trainer_config.vllm_hf_overrides,
544580
"enable_expert_parallel": sampler_config.enable_expert_parallel,
545581
"enable_prefix_caching": True, # Enable prefix caching to speed up generation for long prompts
582+
# Ensures vLLM model initializes with correct dtype (not float32 default)
583+
"dtype": trainer_config.weight_dtype,
546584
},
547585
rollout_vllm_sampling_kwargs={
548586
"stop": trainer_config.stop_strings,

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@
3535
eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16
3636
"""
3737

38-
from typing import Sequence
38+
from typing import Any, Sequence
3939

4040
from absl import app
4141
import os
4242
import jax
4343
import optax
4444
import pathwaysutils
4545

46+
from flax import nnx
4647
from flax.linen import partitioning as nn_partitioning
4748

4849
from orbax import checkpoint as ocp
@@ -68,6 +69,69 @@
6869
from maxtext.utils import model_creation_utils
6970

7071

72+
class MaxTextPeftTrainer(peft_trainer.PeftTrainer):
73+
"""MaxText-specific PeftTrainer that avoids nested NNX transformations.
74+
75+
Tunix's default PeftTrainer._train_step creates nnx.value_and_grad inside
76+
nnx.jit. This nesting causes Flax NNX to assign conflicting outer_index
77+
values to graph nodes, resulting in:
78+
ValueError: The graph structure of a node added to cached_partial was
79+
mutated inside the transformation.
80+
81+
This subclass overrides create_train_step_fn to use jax.value_and_grad
82+
with an explicit split/merge pattern (matching MaxText's pre-training NNX
83+
train_step), which avoids the nested NNX transformation issue entirely.
84+
"""
85+
86+
def create_train_step_fn(self):
87+
"""Creates a train step using jax.value_and_grad with explicit NNX split/merge."""
88+
loss_fn_ref = self.loss_fn
89+
has_aux = self._has_aux
90+
gen_fn = self.gen_model_input_fn
91+
is_lora_enabled = self._lora_enabled
92+
wrt = nnx.LoRAParam if is_lora_enabled else nnx.Param
93+
94+
# Capture the graphdef once outside of JIT so that split/merge inside
95+
# jax.value_and_grad can use a stable (non-traced) structural descriptor.
96+
graphdef, _, _ = nnx.split(self.model, wrt, ...)
97+
98+
def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any):
99+
inputs = gen_fn(inputs)
100+
101+
# Split model into differentiable params and non-differentiable rest.
102+
# Using jax.value_and_grad (not nnx.value_and_grad) avoids nesting NNX
103+
# transforms inside nnx.jit, which would corrupt outer_index tracking.
104+
_, diff_params, rest = nnx.split(model, wrt, ...)
105+
106+
def loss_wrapper(diff_params, rest, **inputs_kw):
107+
local_model = nnx.merge(graphdef, diff_params, rest, copy=True)
108+
out = loss_fn_ref(local_model, **inputs_kw)
109+
# Capture updated non-param state (e.g. RNG counters) from local_model.
110+
_, _, new_rest = nnx.split(local_model, wrt, ...)
111+
if has_aux:
112+
loss, aux = out
113+
return loss, (aux, new_rest)
114+
else:
115+
return out, (None, new_rest)
116+
117+
grad_fn = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True)
118+
(out_val, (aux, new_rest)), grads = grad_fn(diff_params, rest, **inputs)
119+
120+
# Propagate updated non-param state (RNG counters, etc.) back to model.
121+
nnx.update(model, new_rest)
122+
123+
# Apply optimizer update. grads has the same nnx.State(wrt) structure
124+
# as diff_params, which is compatible with optimizer.update.
125+
optimizer.update(model, grads)
126+
127+
if has_aux:
128+
return out_val, aux
129+
else:
130+
return out_val, None
131+
132+
return train_step
133+
134+
71135
def get_tunix_config(mt_config):
72136
"""Gets the Tunix training configurations from the MaxText config.
73137
@@ -161,7 +225,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
161225
training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
162226
data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)
163227

164-
trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
228+
trainer = MaxTextPeftTrainer(model, optimizer, tunix_config)
165229
trainer.with_training_hooks(training_hooks)
166230
trainer.with_data_hooks(data_hooks)
167231
trainer = use_maxtext_loss_function(trainer, mt_config)

src/maxtext/utils/maxtext_utils.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,26 +1800,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
18001800
"""
18011801
Print state shardings comparing Logical Definition vs Physical Result.
18021802
"""
1803-
if not hasattr(params, "params"):
1804-
params = {"params": params}
1805-
if not hasattr(params_sharding, "params"):
1806-
params_sharding = {"params": params_sharding}
1807-
if logical_annotations and not hasattr(logical_annotations, "params"):
1808-
logical_annotations = {"params": logical_annotations}
1803+
if not isinstance(params, nnx.State):
1804+
if not hasattr(params, "params"):
1805+
params = {"params": params}
1806+
if not hasattr(params_sharding, "params"):
1807+
params_sharding = {"params": params_sharding}
1808+
if logical_annotations and not hasattr(logical_annotations, "params"):
1809+
logical_annotations = {"params": logical_annotations}
18091810

18101811
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
18111812
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1812-
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
18131813

1814-
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
1815-
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1816-
shape = jax.typeof(leaf_val)
1817-
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1818-
pspec_str = str(tuple(pspec))
1819-
logical_str = str(leaf_logical_val)
1820-
1821-
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1822-
max_logging.info(message)
1814+
if logical_annotations is not None:
1815+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
1816+
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(
1817+
leaves_params, leaves_sharding, leaves_logical
1818+
):
1819+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1820+
shape = jax.typeof(leaf_val)
1821+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1822+
pspec_str = str(tuple(pspec))
1823+
logical_str = str(leaf_logical_val)
1824+
1825+
message = (
1826+
f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1827+
)
1828+
max_logging.info(message)
1829+
else:
1830+
for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
1831+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1832+
shape = jax.typeof(leaf_val)
1833+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1834+
pspec_str = str(tuple(pspec))
1835+
1836+
message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}"
1837+
max_logging.info(message)
18231838

18241839
print(flush=True)
18251840

src/maxtext/utils/model_creation_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,13 @@ def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAI
372372
# Get the structure of checkpoint in `config.load_parameters_path`
373373
metadata = ckptr.metadata(config.load_parameters_path)
374374

375+
if metadata is None or metadata.item_metadata is None:
376+
raise ValueError(
377+
f"Cannot read checkpoint metadata from '{config.load_parameters_path}'. "
378+
"The checkpoint directory may be empty or the save did not complete "
379+
"(missing _CHECKPOINT_METADATA). Ensure the checkpoint save finished successfully."
380+
)
381+
375382
is_nnx_checkpoint = True
376383
if (
377384
"params" in metadata.item_metadata.tree.keys()

0 commit comments

Comments (0)