Skip to content

Commit 6466f8b

Browse files
Merge pull request #3652 from AI-Hypercomputer:feat/nnx-post-train-fixes
PiperOrigin-RevId: 915486436
2 parents d771e71 + 2eab581 commit 6466f8b

10 files changed

Lines changed: 363 additions & 94 deletions

File tree

src/maxtext/optimizers/optimizers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu):
336336
else:
337337
updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)
338338

339-
step_size = -1.0 * learning_rate_fn(count)
339+
# learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped
340+
# by optax.inject_hyperparams, it is passed as a pre-evaluated scalar).
341+
step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn)
340342
# Finally, fold in step size.
341343
updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)
342344

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -274,30 +274,45 @@ def wrt_filter(path, x):
274274
# Inherits _shard_optimizer from PeftTrainer.
275275

276276
def _train_step(self, model, optimizer, inputs):
277-
"""Overrides the main JIT block to natively handle ModelBundle module."""
277+
"""Overrides the main JIT block to natively handle ModelBundle module.
278278
279+
Uses jax.value_and_grad with explicit split/merge to avoid nesting
280+
nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign
281+
conflicting outer_index values and raises:
282+
ValueError: The graph structure of a node added to cached_partial was
283+
mutated inside the transformation.
284+
"""
279285
batch = self.gen_model_input_fn(inputs)
286+
student = model.student_model
287+
teacher = model.teacher_model
280288
current_step = model.training_step[...]
281289

282-
def loss_wrapper(student, teacher, batch):
283-
if "teacher_output" in batch:
284-
teacher_output = batch["teacher_output"]
285-
else:
286-
teacher_output = self.strategy.teacher_forward_fn(
287-
model=teacher,
288-
input_tokens=batch["input_tokens"],
289-
positions=batch["positions"],
290-
attention_mask=batch.get("attention_mask"),
291-
decoder_segment_ids=batch.get("decoder_segment_ids"),
292-
decoder_target_tokens=batch.get("targets", None),
293-
decoder_target_mask=batch.get("targets_segmentation", None),
294-
cache=None,
295-
)
290+
# Run teacher inference outside of value_and_grad.
291+
# The teacher is frozen (stop_gradient), so its output is a constant
292+
# from the perspective of the student gradient computation.
293+
if "teacher_output" in batch:
294+
teacher_output = batch["teacher_output"]
295+
else:
296+
teacher_output = self.strategy.teacher_forward_fn(
297+
model=teacher,
298+
input_tokens=batch["input_tokens"],
299+
positions=batch["positions"],
300+
attention_mask=batch.get("attention_mask"),
301+
decoder_segment_ids=batch.get("decoder_segment_ids"),
302+
decoder_target_tokens=batch.get("targets", None),
303+
decoder_target_mask=batch.get("targets_segmentation", None),
304+
cache=None,
305+
)
306+
teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
296307

297-
teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
308+
# Split student into differentiable params and non-differentiable rest.
309+
# Capture graphdef outside of jax.value_and_grad for stable graph tracking.
310+
student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...)
298311

312+
def loss_wrapper_pure(diff_params, rest):
313+
local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True)
299314
student_output = self.strategy.student_forward_fn(
300-
model=student,
315+
model=local_student,
301316
input_tokens=batch["input_tokens"],
302317
positions=batch["positions"],
303318
attention_mask=batch.get("attention_mask"),
@@ -306,29 +321,26 @@ def loss_wrapper(student, teacher, batch):
306321
decoder_target_mask=batch.get("targets_segmentation", None),
307322
cache=None,
308323
)
309-
# we should apply a mask for labels to disable segment-separator tokens
310324
labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
311-
return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)
312-
313-
# Because student is the 0th argument, argnums=0 guarantees
314-
# we only compute gradients for the student.
315-
grad_fn = nnx.value_and_grad(
316-
loss_wrapper,
317-
argnums=nnx.DiffState(0, self.wrt_filter),
318-
has_aux=True,
319-
)
325+
loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)
326+
# Capture updated non-param state (e.g. RNG counters) from local_student.
327+
_, _, new_rest = nnx.split(local_student, self.wrt_filter, ...)
328+
return loss, (aux, new_rest)
320329

321-
out, grads = grad_fn(model.student_model, model.teacher_model, batch)
330+
grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True)
331+
(loss, (aux, new_rest)), grads = grad_fn(diff_params, rest)
322332

323-
model.training_step.set_value(current_step + 1)
333+
# Propagate updated non-param state back to student.
334+
nnx.update(student, new_rest)
324335

325-
tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
336+
optimizer.update(student, grads)
326337

327-
optimizer.update(model.student_model, grads)
338+
model.training_step.set_value(current_step + 1)
328339

340+
tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
329341
if tunix_expects_grad_norm:
330-
return out[0], out[1], optax.global_norm(grads)
331-
return out[0], out[1]
342+
return loss, aux, optax.global_norm(grads)
343+
return loss, aux
332344

333345
def _eval_step(self, model, inputs):
334346
"""Evaluation only needs the student."""

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@
4444
"""
4545

4646
from __future__ import annotations
47+
import contextlib
4748
from functools import wraps
4849
from typing import Any, Optional, Sequence
4950

5051
import datasets
5152
import grain
5253
import jax
54+
import jax.numpy as jnp
5355
import json
5456
import logging
5557
import os
@@ -67,6 +69,48 @@
6769
from tunix.rl.rollout import base_rollout
6870
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
6971
from tunix.sft import metrics_logger, profiler
72+
import tunix.generate.utils as tunix_utils
73+
74+
75+
@contextlib.contextmanager
def _tpu_inference_compat_patches():
  """Temporarily monkeypatches two functions that tpu_inference depends on.

  While the context is active:

  * ``jax.lax.with_sharding_constraint`` is swapped for a wrapper that calls the
    original and, if that raises ``AssertionError``, retries the same request
    through ``jax.sharding.reshard``. Per the original author's note, current
    jax asserts on a sharding mismatch when all mesh axes are Explicit, whereas
    tpu_inference assumes a silent reshard.
  * ``tunix.generate.utils._apply_dtype_cast`` is swapped for a wrapper that
    returns the value untouched when a bfloat16 value would be cast to float32,
    and defers to the original cast in every other case. Per the original
    author's note, the upcast made synced weights float32, which then
    mismatched in the ragged paged attention kernel.

  Both attributes are put back when the context exits, including when the
  wrapped body raises. The patches are meant to be scoped to rl_train() and
  dropped once tpu_inference is updated upstream.

  NOTE(review): only callers that look the functions up through ``jax.lax`` /
  ``tunix_utils`` at call time see the replacement; names bound earlier via
  ``from ... import ...`` presumably keep the original - confirm if relevant.
  """
  saved_constraint_fn = jax.lax.with_sharding_constraint
  saved_cast_fn = tunix_utils._apply_dtype_cast  # pylint: disable=protected-access

  def _constraint_or_reshard(x, shardings):
    # Prefer the real implementation; fall back to an explicit reshard only
    # when it rejects the request with an assertion.
    try:
      constrained = saved_constraint_fn(x, shardings)
    except AssertionError:
      return jax.sharding.reshard(x, shardings)
    return constrained

  def _cast_unless_bf16_upcast(val, tgt_dtype, src_key):
    # A bfloat16 value headed for float32 is passed through unchanged.
    is_bf16_upcast = hasattr(val, "dtype") and val.dtype == jnp.bfloat16 and tgt_dtype == jnp.float32
    if not is_bf16_upcast:
      return saved_cast_fn(val, tgt_dtype, src_key)
    return val

  jax.lax.with_sharding_constraint = _constraint_or_reshard
  tunix_utils._apply_dtype_cast = _cast_unless_bf16_upcast  # pylint: disable=protected-access
  try:
    yield
  finally:
    # Always restore the originals so the shims never outlive the context.
    jax.lax.with_sharding_constraint = saved_constraint_fn
    tunix_utils._apply_dtype_cast = saved_cast_fn  # pylint: disable=protected-access
113+
70114

71115
os.environ["TOKENIZERS_PARALLELISM"] = "0"
72116

@@ -418,6 +462,8 @@ def create_rl_components(
418462
"hf_overrides": trainer_config.vllm_hf_overrides,
419463
"enable_expert_parallel": sampler_config.enable_expert_parallel,
420464
"enable_prefix_caching": True, # Enable prefix caching to speed up generation for long prompts
465+
# Ensures vLLM model initializes with correct dtype (not float32 default)
466+
"dtype": trainer_config.weight_dtype,
421467
},
422468
rollout_vllm_sampling_kwargs={
423469
"stop": trainer_config.stop_strings,
@@ -539,6 +585,12 @@ def rl_train(argv: Sequence[str], kwargs: dict):
539585
trainer_devices: JAX devices for the trainer.
540586
sampler_devices: JAX devices for the sampler.
541587
"""
588+
with _tpu_inference_compat_patches():
589+
_rl_train_impl(argv, kwargs)
590+
591+
592+
def _rl_train_impl(argv: Sequence[str], kwargs: dict):
593+
"""rl_train body — kept separate so _tpu_inference_compat_patches wraps it cleanly."""
542594
trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
543595
argv, kwargs
544596
)
@@ -563,7 +615,10 @@ def rl_train(argv: Sequence[str], kwargs: dict):
563615
max_train_steps = get_max_train_steps(trainer_config)
564616

565617
# Create model tokenizer
566-
model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
618+
model_tokenizer = AutoTokenizer.from_pretrained(
619+
trainer_config.tokenizer_path,
620+
token=trainer_config.hf_access_token or None,
621+
)
567622

568623
train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer)
569624

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,16 @@
3535
eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16
3636
"""
3737

38-
from typing import Sequence
38+
import inspect
39+
from typing import Any, Sequence
3940

4041
from absl import app
4142
import os
4243
import jax
4344
import optax
4445
import pathwaysutils
4546

47+
from flax import nnx
4648
from flax.linen import partitioning as nn_partitioning
4749

4850
from orbax import checkpoint as ocp
@@ -69,6 +71,78 @@
6971
from maxtext.utils import model_creation_utils
7072

7173

74+
class MaxTextPeftTrainer(peft_trainer.PeftTrainer):
  """PeftTrainer variant whose train step differentiates with plain JAX.

  According to the original author's note, Tunix's stock
  PeftTrainer._train_step builds nnx.value_and_grad inside nnx.jit, and that
  nesting makes Flax NNX assign conflicting outer_index values to graph nodes,
  failing with:
    ValueError: The graph structure of a node added to cached_partial was
    mutated inside the transformation.

  This subclass replaces create_train_step_fn with a version that calls
  jax.value_and_grad on an explicitly split/merged model instead, so no NNX
  transform is created inside the step.
  """

  def create_train_step_fn(self):
    """Returns a train-step callable built on jax.value_and_grad.

    The returned function takes (model, optimizer, inputs) and returns
    (loss, aux) or (loss, aux, grad_norm); aux is None when the trainer's loss
    function has no auxiliary output.
    """
    compute_loss = self.loss_fn
    returns_aux = self._has_aux
    prepare_inputs = self.gen_model_input_fn
    lora_on = self._lora_enabled
    # Only LoRA parameters are differentiated when LoRA is enabled.
    trainable = nnx.LoRAParam if lora_on else nnx.Param

    # Decide between the 2-tuple and 3-tuple return shape by checking whether
    # the source of PeftTrainer._train_step mentions "grad_norm". If the
    # source cannot be read, fall back to the 2-tuple form.
    try:
      step_source = inspect.getsource(peft_trainer.PeftTrainer._train_step)  # pylint: disable=protected-access
    except (TypeError, OSError):
      emit_grad_norm = False
    else:
      emit_grad_norm = "grad_norm" in step_source

    # The graphdef is taken once, here, and reused by every merge inside the
    # differentiated function.
    model_graphdef = nnx.split(self.model, trainable, ...)[0]

    def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any):
      batch_kwargs = prepare_inputs(inputs)

      # Separate the state to differentiate from everything else.
      _, trainable_state, frozen_state = nnx.split(model, trainable, ...)

      def loss_wrapper(diff_params, rest, **inputs_kw):
        # Rebuild a private copy of the model from the two state pieces.
        replica = nnx.merge(model_graphdef, diff_params, rest, copy=True)
        result = compute_loss(replica, **inputs_kw)
        # Non-differentiated state may have changed during the forward pass
        # (e.g. RNG counters); hand it back through the aux channel.
        _, _, rest_after = nnx.split(replica, trainable, ...)
        if not returns_aux:
          return result, (None, rest_after)
        loss, aux = result
        return loss, (aux, rest_after)

      differentiate = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True)
      (loss_value, (aux, updated_rest)), grads = differentiate(trainable_state, frozen_state, **batch_kwargs)

      # Write the refreshed non-differentiated state into the live model.
      nnx.update(model, updated_rest)

      # grads is the gradient with respect to trainable_state.
      optimizer.update(model, grads)

      reported_aux = aux if returns_aux else None
      if not emit_grad_norm:
        return loss_value, reported_aux
      return loss_value, reported_aux, optax.global_norm(grads)

    return train_step
144+
145+
72146
def get_tunix_config(mt_config):
73147
"""Gets the Tunix training configurations from the MaxText config.
74148
@@ -110,6 +184,7 @@ def get_tunix_config(mt_config):
110184
checkpointing_options=checkpointing_options,
111185
metrics_logging_options=metrics_logging_options,
112186
profiler_options=profiler_options,
187+
data_sharding_axis=tuple(mt_config.data_sharding),
113188
)
114189

115190

@@ -176,10 +251,9 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
176251

177252
# Provide rules context so 'norm' is translated to mesh axes during maybe_restore
178253
with nn_partitioning.axis_rules(mt_config.logical_axis_rules):
179-
trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
254+
trainer = MaxTextPeftTrainer(model, optimizer, tunix_config)
180255
if mt_config.lora.lora_restore_path:
181256
trainer = lora_utils.restore_lora_from_path(trainer, mt_config)
182-
183257
trainer.with_training_hooks(training_hooks)
184258
trainer.with_data_hooks(data_hooks)
185259
trainer = use_maxtext_loss_function(trainer, mt_config)

src/maxtext/utils/maxtext_utils.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,26 +1910,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
19101910
"""
19111911
Print state shardings comparing Logical Definition vs Physical Result.
19121912
"""
1913-
if not hasattr(params, "params"):
1914-
params = {"params": params}
1915-
if not hasattr(params_sharding, "params"):
1916-
params_sharding = {"params": params_sharding}
1917-
if logical_annotations and not hasattr(logical_annotations, "params"):
1918-
logical_annotations = {"params": logical_annotations}
1913+
if not isinstance(params, nnx.State):
1914+
if not hasattr(params, "params"):
1915+
params = {"params": params}
1916+
if not hasattr(params_sharding, "params"):
1917+
params_sharding = {"params": params_sharding}
1918+
if logical_annotations and not hasattr(logical_annotations, "params"):
1919+
logical_annotations = {"params": logical_annotations}
19191920

19201921
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
19211922
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1922-
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
19231923

1924-
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
1925-
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1926-
shape = jax.typeof(leaf_val)
1927-
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1928-
pspec_str = str(tuple(pspec))
1929-
logical_str = str(leaf_logical_val)
1930-
1931-
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1932-
max_logging.info(message)
1924+
if logical_annotations is not None:
1925+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
1926+
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(
1927+
leaves_params, leaves_sharding, leaves_logical
1928+
):
1929+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1930+
shape = jax.typeof(leaf_val)
1931+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1932+
pspec_str = str(tuple(pspec))
1933+
logical_str = str(leaf_logical_val)
1934+
1935+
message = (
1936+
f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1937+
)
1938+
max_logging.info(message)
1939+
else:
1940+
for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
1941+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1942+
shape = jax.typeof(leaf_val)
1943+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1944+
pspec_str = str(tuple(pspec))
1945+
1946+
message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}"
1947+
max_logging.info(message)
19331948

19341949
print(flush=True)
19351950

0 commit comments

Comments
 (0)