Commit 28b1e4a
NNX: add sharding tools, Linen<->NNX checkpoint utilities, and post-training fixes
Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)

Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as the multimodal_input= kwarg; NNXDecoder only accepts the individual fields)
- optimizers.py: guard adam_pax against a scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn)
- train_distill.py: fix a nested NNX transform issue (nnx.value_and_grad inside nnx.jit raises a conflicting outer_index error); refactored to jax.value_and_grad plus an explicit nnx.split/merge pattern; teacher inference moved outside value_and_grad
1 parent 2e9d0e9 commit 28b1e4a

14 files changed: 2863 additions & 89 deletions

src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py

Lines changed: 609 additions & 0 deletions
Large diffs are not rendered by default.

src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Lines changed: 581 additions & 0 deletions
Large diffs are not rendered by default.

src/maxtext/models/models.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -520,7 +520,11 @@ def __call__(
         previous_chunk=previous_chunk,
         slot=slot,
         page_state=page_state,
-        multimodal_input=multimodal_input,
+        image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None,
+        image_masks=multimodal_input.image_masks if multimodal_input is not None else None,
+        audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None,
+        audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None,
+        bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None,
         kv_caches=kv_caches,
         attention_metadata=attention_metadata,
         deepstack_visual_embeds=deepstack_visual_embeds,
```
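For context, the fix stops passing the whole MultimodalInput object as a single kwarg and forwards its fields individually. A minimal sketch of the equivalent unpacking, assuming MultimodalInput is a plain dataclass with exactly the five fields named in the diff (the dataclass and helper below are hypothetical, for illustration only):

```python
# Hypothetical sketch; MultimodalInput's real definition lives elsewhere in
# MaxText, and only the five field names from the diff above are assumed.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MultimodalInput:
  image_embeddings: Optional[Any] = None
  image_masks: Optional[Any] = None
  audio_embeddings: Optional[Any] = None
  audio_masks: Optional[Any] = None
  bidirectional_mask: Optional[Any] = None


def unpack_multimodal(mm: Optional[MultimodalInput]) -> dict:
  """Expands MultimodalInput into the individual kwargs NNXDecoder accepts."""
  fields = ("image_embeddings", "image_masks", "audio_embeddings", "audio_masks", "bidirectional_mask")
  return {name: (getattr(mm, name) if mm is not None else None) for name in fields}


# Usage at a call site would look like: decoder(..., **unpack_multimodal(multimodal_input))
```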

src/maxtext/optimizers/optimizers.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu):
     else:
       updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)
 
-    step_size = -1.0 * learning_rate_fn(count)
+    # learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped
+    # by optax.inject_hyperparams, it is passed as a pre-evaluated scalar).
+    step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn)
     # Finally, fold in step size.
     updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)
```
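Why the guard is needed: optax.inject_hyperparams evaluates a schedule itself and hands the wrapped optimizer a pre-evaluated scalar, so adam_pax cannot assume learning_rate_fn is callable. A small illustration of that behavior, a sketch with made-up values using only public optax API:

```python
import jax.numpy as jnp
import optax

# inject_hyperparams evaluates the schedule at each step and passes the result
# into the wrapped optimizer as a plain scalar, not as a callable.
schedule = optax.linear_schedule(init_value=3e-4, end_value=3e-5, transition_steps=1000)
opt = optax.inject_hyperparams(optax.sgd)(learning_rate=schedule)

state = opt.init({"w": jnp.zeros(3)})
# The evaluated scalar is visible in the optimizer state:
print(state.hyperparams["learning_rate"])  # 0.0003 at step 0
```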

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 45 additions & 33 deletions
```diff
@@ -251,30 +251,45 @@ def wrt_filter(path, x):
   # Inherits _shard_optimizer from PeftTrainer.
 
   def _train_step(self, model, optimizer, inputs):
-    """Overrides the main JIT block to natively handle ModelBundle module."""
+    """Overrides the main JIT block to natively handle ModelBundle module.
 
+    Uses jax.value_and_grad with explicit split/merge to avoid nesting
+    nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign
+    conflicting outer_index values and raises:
+      ValueError: The graph structure of a node added to cached_partial was
+      mutated inside the transformation.
+    """
     batch = self.gen_model_input_fn(inputs)
+    student = model.student_model
+    teacher = model.teacher_model
     current_step = model.training_step.value
 
-    def loss_wrapper(student, teacher, batch):
-      if "teacher_output" in batch:
-        teacher_output = batch["teacher_output"]
-      else:
-        teacher_output = self.strategy.teacher_forward_fn(
-            model=teacher,
-            input_tokens=batch["input_tokens"],
-            positions=batch["positions"],
-            attention_mask=batch.get("attention_mask"),
-            decoder_segment_ids=batch.get("decoder_segment_ids"),
-            decoder_target_tokens=batch.get("targets", None),
-            decoder_target_mask=batch.get("targets_segmentation", None),
-            cache=None,
-        )
+    # Run teacher inference outside of value_and_grad.
+    # The teacher is frozen (stop_gradient), so its output is a constant
+    # from the perspective of the student gradient computation.
+    if "teacher_output" in batch:
+      teacher_output = batch["teacher_output"]
+    else:
+      teacher_output = self.strategy.teacher_forward_fn(
+          model=teacher,
+          input_tokens=batch["input_tokens"],
+          positions=batch["positions"],
+          attention_mask=batch.get("attention_mask"),
+          decoder_segment_ids=batch.get("decoder_segment_ids"),
+          decoder_target_tokens=batch.get("targets", None),
+          decoder_target_mask=batch.get("targets_segmentation", None),
+          cache=None,
+      )
+    teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
 
-      teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
+    # Split student into differentiable params and non-differentiable rest.
+    # Capture graphdef outside of jax.value_and_grad for stable graph tracking.
+    student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...)
 
+    def loss_wrapper_pure(diff_params, rest):
+      local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True)
       student_output = self.strategy.student_forward_fn(
-          model=student,
+          model=local_student,
           input_tokens=batch["input_tokens"],
           positions=batch["positions"],
           attention_mask=batch.get("attention_mask"),
@@ -283,30 +298,27 @@ def loss_wrapper(student, teacher, batch):
           decoder_target_mask=batch.get("targets_segmentation", None),
           cache=None,
       )
-      # we should apply a mask for labels to disable segment-separator tokens
       labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
-      return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)
-
-    # Because student is the 0th argument, argnums=0 guarantees
-    # we only compute gradients for the student.
-    grad_fn = nnx.value_and_grad(
-        loss_wrapper,
-        argnums=nnx.DiffState(0, self.wrt_filter),
-        has_aux=True,
-    )
+      loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)
+      # Capture updated non-param state (e.g. RNG counters) from local_student.
+      _, _, new_rest = nnx.split(local_student, self.wrt_filter, ...)
+      return loss, (aux, new_rest)
 
-    out, grads = grad_fn(model.student_model, model.teacher_model, batch)
+    grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True)
+    (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest)
+
+    # Propagate updated non-param state back to student.
+    nnx.update(student, new_rest)
+
+    optimizer.update(student, grads)
 
     # Increment step counter after loss computation
     model.training_step.value = current_step + 1
 
     tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
-
-    optimizer.update(model.student_model, grads)
-
     if tunix_expects_grad_norm:
-      return out[0], out[1], optax.global_norm(grads)
-    return out[0], out[1]
+      return loss, aux, optax.global_norm(grads)
+    return loss, aux
 
   def _eval_step(self, model, inputs):
     """Evaluation only needs the student."""
```

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 42 additions & 1 deletion
```diff
@@ -55,6 +55,42 @@
 import pathwaysutils
 import tensorflow_datasets as tfds
 
+# JAX 0.9+ changed with_sharding_constraint to assert (not reshard) when all
+# mesh axes are Explicit. tpu_inference still expects resharding semantics.
+# Patch: try the original (works for Auto axes); on AssertionError (Explicit
+# mesh) fall back to jax.sharding.reshard.
+_orig_wsc = jax.lax.with_sharding_constraint
+
+
+def _compat_wsc(x, shardings):
+  try:
+    return _orig_wsc(x, shardings)
+  except AssertionError:
+    return jax.sharding.reshard(x, shardings)
+
+
+jax.lax.with_sharding_constraint = _compat_wsc
+
+# tpu_inference JaxEinsum defaults param_dtype=float32, so tpu_inference model weights
+# initialize as float32. During weight sync, tunix._apply_dtype_cast then upcasts the
+# incoming bfloat16 MaxText weights → float32 to match the target. This leaves v_proj
+# as float32 while k_proj output appears bfloat16 (due to k_norm dtype promotion),
+# causing a dtype mismatch in the ragged paged attention kernel.
+# Fix: skip bfloat16→float32 upcasts during weight sync so synced weights stay bfloat16.
+import jax.numpy as _jnp
+import tunix.generate.utils as _tunix_utils
+
+_orig_apply_dtype_cast = _tunix_utils._apply_dtype_cast  # pylint: disable=protected-access
+
+
+def _no_bf16_to_f32_cast(val, tgt_dtype, src_key):
+  if hasattr(val, "dtype") and val.dtype == _jnp.bfloat16 and tgt_dtype == _jnp.float32:
+    return val  # keep bfloat16; tpu_inference model dtype is bfloat16 despite float32 init
+  return _orig_apply_dtype_cast(val, tgt_dtype, src_key)
+
+
+_tunix_utils._apply_dtype_cast = _no_bf16_to_f32_cast  # pylint: disable=protected-access
+
 from absl import app
 from absl import logging as absl_logging
 from etils import epath
@@ -421,6 +457,8 @@ def create_rl_components(
         "hf_overrides": trainer_config.vllm_hf_overrides,
         "enable_expert_parallel": sampler_config.enable_expert_parallel,
         "enable_prefix_caching": True,  # Enable prefix caching to speed up generation for long prompts
+        # Ensures vLLM model initializes with correct dtype (not float32 default)
+        "dtype": trainer_config.weight_dtype,
     },
     rollout_vllm_sampling_kwargs={
         "stop": trainer_config.stop_strings,
@@ -568,7 +606,10 @@ def rl_train(argv: Sequence[str], kwargs: dict):
   max_train_steps = get_max_train_steps(trainer_config)
 
   # Create model tokenizer
-  model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
+  model_tokenizer = AutoTokenizer.from_pretrained(
+      trainer_config.tokenizer_path,
+      token=trainer_config.hf_access_token or None,
+  )
 
   train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer)
```

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 68 additions & 2 deletions
```diff
@@ -35,14 +35,15 @@
 eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16
 """
 
-from typing import Sequence
+from typing import Any, Sequence
 
 from absl import app
 import os
 import jax
 import optax
 import pathwaysutils
 
+from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 
 from orbax import checkpoint as ocp
@@ -68,6 +69,70 @@
 from maxtext.utils import model_creation_utils
 
 
+class MaxTextPeftTrainer(peft_trainer.PeftTrainer):
+  """MaxText-specific PeftTrainer that avoids nested NNX transformations.
+
+  Tunix's default PeftTrainer._train_step creates nnx.value_and_grad inside
+  nnx.jit. This nesting causes Flax NNX to assign conflicting outer_index
+  values to graph nodes, resulting in:
+    ValueError: The graph structure of a node added to cached_partial was
+    mutated inside the transformation.
+
+  This subclass overrides create_train_step_fn to use jax.value_and_grad
+  with an explicit split/merge pattern (matching MaxText's pre-training NNX
+  train_step), which avoids the nested NNX transformation issue entirely.
+  """
+
+  def create_train_step_fn(self):
+    """Creates a train step using jax.value_and_grad with explicit NNX split/merge."""
+    loss_fn_ref = self.loss_fn
+    has_aux = self._has_aux
+    gen_fn = self.gen_model_input_fn
+    is_lora_enabled = self._lora_enabled
+    wrt = nnx.LoRAParam if is_lora_enabled else nnx.Param
+    tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
+
+    # Capture the graphdef once outside of JIT so that split/merge inside
+    # jax.value_and_grad can use a stable (non-traced) structural descriptor.
+    graphdef, _, _ = nnx.split(self.model, wrt, ...)
+
+    def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any):
+      inputs = gen_fn(inputs)
+
+      # Split model into differentiable params and non-differentiable rest.
+      # Using jax.value_and_grad (not nnx.value_and_grad) avoids nesting NNX
+      # transforms inside nnx.jit, which would corrupt outer_index tracking.
+      _, diff_params, rest = nnx.split(model, wrt, ...)
+
+      def loss_wrapper(diff_params, rest, **inputs_kw):
+        local_model = nnx.merge(graphdef, diff_params, rest, copy=True)
+        out = loss_fn_ref(local_model, **inputs_kw)
+        # Capture updated non-param state (e.g. RNG counters) from local_model.
+        _, _, new_rest = nnx.split(local_model, wrt, ...)
+        if has_aux:
+          loss, aux = out
+          return loss, (aux, new_rest)
+        else:
+          return out, (None, new_rest)
+
+      grad_fn = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True)
+      (out_val, (aux, new_rest)), grads = grad_fn(diff_params, rest, **inputs)
+
+      # Propagate updated non-param state (RNG counters, etc.) back to model.
+      nnx.update(model, new_rest)
+
+      # Apply optimizer update. grads has the same nnx.State(wrt) structure
+      # as diff_params, which is compatible with optimizer.update.
+      optimizer.update(model, grads)
+
+      aux_out = aux if has_aux else None
+      if tunix_expects_grad_norm:
+        return out_val, aux_out, optax.global_norm(grads)
+      return out_val, aux_out
+
+    return train_step
+
+
 def get_tunix_config(mt_config):
   """Gets the Tunix training configurations from the MaxText config.
@@ -109,6 +174,7 @@ def get_tunix_config(mt_config):
         checkpointing_options=checkpointing_options,
         metrics_logging_options=metrics_logging_options,
         profiler_options=profiler_options,
+        data_sharding_axis=tuple(mt_config.data_sharding),
     )
@@ -162,7 +228,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
   data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)
   # Provide rules context so 'norm' is translated to mesh axes during maybe_restore
   with nn_partitioning.axis_rules(mt_config.logical_axis_rules):
-    trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
+    trainer = MaxTextPeftTrainer(model, optimizer, tunix_config)
   trainer.with_training_hooks(training_hooks)
   trainer.with_data_hooks(data_hooks)
   trainer = use_maxtext_loss_function(trainer, mt_config)
```

src/maxtext/utils/maxtext_utils.py

Lines changed: 31 additions & 16 deletions
```diff
@@ -1852,26 +1852,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=None
   """
   Print state shardings comparing Logical Definition vs Physical Result.
   """
-  if not hasattr(params, "params"):
-    params = {"params": params}
-  if not hasattr(params_sharding, "params"):
-    params_sharding = {"params": params_sharding}
-  if logical_annotations and not hasattr(logical_annotations, "params"):
-    logical_annotations = {"params": logical_annotations}
+  if not isinstance(params, nnx.State):
+    if not hasattr(params, "params"):
+      params = {"params": params}
+    if not hasattr(params_sharding, "params"):
+      params_sharding = {"params": params_sharding}
+    if logical_annotations and not hasattr(logical_annotations, "params"):
+      logical_annotations = {"params": logical_annotations}
 
   leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
   leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
-  leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
 
-  for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
-    path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
-    shape = jax.typeof(leaf_val)
-    pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
-    pspec_str = str(tuple(pspec))
-    logical_str = str(leaf_logical_val)
-
-    message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
-    max_logging.info(message)
+  if logical_annotations is not None:
+    leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
+    for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(
+        leaves_params, leaves_sharding, leaves_logical
+    ):
+      path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
+      shape = jax.typeof(leaf_val)
+      pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
+      pspec_str = str(tuple(pspec))
+      logical_str = str(leaf_logical_val)
+
+      message = (
+          f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
+      )
+      max_logging.info(message)
+  else:
+    for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
+      path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
+      shape = jax.typeof(leaf_val)
+      pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
+      pspec_str = str(tuple(pspec))
+
+      message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}"
+      max_logging.info(message)
 
   print(flush=True)
```
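With the new isinstance(params, nnx.State) branch, the function accepts NNX state directly instead of a Linen-style {"params": ...} dict. A hedged usage sketch (assumes `model` is an NNX module whose arrays already carry shardings on `mesh`, and that tree-mapping over nnx.State yields array leaves, which holds for recent Flax versions):

```python
# Hypothetical call site, for illustration only; `model` and `mesh` are assumed
# to exist. nnx.State is a pytree, so jax.tree.map reaches the array leaves.
import jax
from flax import nnx

_, params = nnx.split(model, nnx.Param)
params_sharding = jax.tree.map(lambda leaf: leaf.sharding, params)
print_shardings_params(params, params_sharding, mesh)  # logical_annotations=None path
```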

src/maxtext/utils/model_creation_utils.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -546,6 +546,13 @@ def from_pretrained(
   # Get the structure of checkpoint in `config.load_parameters_path`
   metadata = ckptr.metadata(config.load_parameters_path)
 
+  if metadata is None or metadata.item_metadata is None:
+    raise ValueError(
+        f"Cannot read checkpoint metadata from '{config.load_parameters_path}'. "
+        "The checkpoint directory may be empty or the save did not complete "
+        "(missing _CHECKPOINT_METADATA). Ensure the checkpoint save finished successfully."
+    )
+
   def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx):
     if not hasattr(target, "items") or not hasattr(meta_tree, "items"):
       return target
```
