Skip to content

Commit bbfcef6

Browse files
Merge pull request #3188 from AI-Hypercomputer:hengtaoguo-nnx-logits
PiperOrigin-RevId: 907034928
2 parents 87ed81f + 9577539 commit bbfcef6

3 files changed

Lines changed: 123 additions & 11 deletions

File tree

src/maxtext/utils/model_creation_utils.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import numpy as np
3232
from jax.sharding import Mesh
3333
from maxtext.configs import pyconfig
34-
from maxtext.common.common_types import MODEL_MODE_TRAIN
34+
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN
3535
from maxtext.layers import quantizations
3636
from maxtext.models import models
3737
from maxtext.utils import max_logging
@@ -580,20 +580,21 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx):
580580
}
581581
}
582582
else:
583-
# structure of nnx checkpoint: {'decoder': {'value': ...}}
583+
# NNX checkpoint: {'decoder': {'value': ...}}, or NNX-RL with extra 'base' nesting.
584+
# Restore only nnx.Param — RNG variable shapes may differ between checkpoint and model.
584585
target_for_restore = jax.tree.map(
585586
lambda v: {"value": v.value},
586587
sharded_state,
587588
is_leaf=lambda n: isinstance(n, nnx.Variable),
588589
)
589-
target_for_restore = _adjust_target_for_moe_fusion(target_for_restore, metadata.item_metadata.tree, True)
590-
item_to_restore = target_for_restore
591-
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
590+
has_base_key = "base" in metadata.item_metadata.tree
591+
meta_tree_for_params = metadata.item_metadata.tree.get("base", metadata.item_metadata.tree)
592+
target_for_restore = _adjust_target_for_moe_fusion(target_for_restore, meta_tree_for_params, True)
593+
item_to_restore = {"base": target_for_restore} if has_base_key else target_for_restore
592594
restore_args = _fix_restore_args_for_shape_mismatch(
593-
base_restore_args,
594-
metadata.item_metadata.tree,
595-
mesh,
595+
ocp.checkpoint_utils.construct_restore_args(target_for_restore), meta_tree_for_params, mesh
596596
)
597+
restore_args = {"base": restore_args} if has_base_key else restore_args
597598

598599
restored = ckptr.restore(
599600
epath.Path(config.load_parameters_path),
@@ -603,9 +604,10 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx):
603604
)
604605

605606
if is_nnx_checkpoint:
607+
restored_root = restored["base"] if has_base_key else restored
606608
checkpoint = jax.tree.map(
607609
lambda v: v["value"],
608-
restored,
610+
restored_root,
609611
is_leaf=lambda x: isinstance(x, dict) and "value" in x and not isinstance(x.get("value"), dict),
610612
)
611613
else:
@@ -656,6 +658,13 @@ def _fuse_moe_weights(ckpt_tree, model_arrays_tree):
656658
# This prevents the replicated intermediate copies from persisting until function return.
657659
del restored
658660

661+
def _filter_to_model_keys(ckpt, model):
662+
"""Recursively keep only keys present in model, dropping checkpoint-only fields (e.g. to_nnx__rngs)."""
663+
if not hasattr(ckpt, "items") or not hasattr(model, "items"):
664+
return ckpt
665+
return {k: _filter_to_model_keys(ckpt[k], model[k]) for k in model if k in ckpt}
666+
667+
checkpoint = _filter_to_model_keys(checkpoint, model_arrays)
659668
checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays)
660669
nnx.update(model, checkpoint)
661670

@@ -672,3 +681,44 @@ def _fuse_moe_weights(ckpt_tree, model_arrays_tree):
672681
return model
673682
else:
674683
return model, mesh
684+
685+
686+
def setup_decode_state_from_nnx(model, config, rng, mesh):
687+
"""Setup decode state by loading an NNX or NNX-RL checkpoint into a linen TrainState.
688+
689+
Calls from_pretrained (which handles NNX and NNX-RL 'base'-nested checkpoints and
690+
applies mesh sharding internally), then extracts nnx.Param values into a plain dict
691+
for the linen TrainState. For linen checkpoints, use maxtext_utils.setup_decode_state instead.
692+
693+
Args:
694+
model: the flax linen model to initialize
695+
config: config object
696+
rng: jax.prng key
697+
mesh: jax.devices() mesh
698+
699+
Returns:
700+
state: linen TrainState with params loaded from the NNX checkpoint
701+
state_mesh_annotations: the mesh annotations for the state
702+
"""
703+
init_state_fn = partial(maxtext_utils.init_initial_state, model, None, config, False, rng)
704+
_, state_mesh_annotations, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, False)
705+
706+
# Load the NNX model; from_pretrained handles sharding via jax.jit(out_shardings=...).
707+
nnx_model = from_pretrained(config, mesh=mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE)
708+
709+
# Extract nnx.Param values, converting the State pytree to a plain nested dict.
710+
def _state_to_dict(tree):
711+
if isinstance(tree, nnx.Variable):
712+
return tree.value
713+
if hasattr(tree, "items") and not isinstance(tree, jax.Array):
714+
return {k: _state_to_dict(v) for k, v in tree.items()}
715+
return tree
716+
717+
nnx_param_state = nnx.state(nnx_model, nnx.Param)
718+
raw_params = _state_to_dict(nnx_param_state)
719+
del nnx_model, nnx_param_state # free memory
720+
721+
params = {"params": raw_params}
722+
723+
state = maxtext_utils.init_decode_state(model.apply, params)
724+
return state, state_mesh_annotations

tests/unit/model_creation_utils_test.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@
2626
from jax.sharding import Mesh
2727
from orbax import checkpoint as ocp
2828

29+
from flax.training import train_state
2930
from maxtext.configs import pyconfig
30-
from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL
31+
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL
3132
from maxtext.models import models
3233
from maxtext.utils import maxtext_utils
3334
from maxtext.utils import model_creation_utils
@@ -393,5 +394,54 @@ def test_checkpoint_load_error_raises_value_error(self, mock_ocp):
393394
model_creation_utils.from_pretrained(cfg, self.mesh)
394395

395396

397+
class TestSetupDecodeStateFromNnx(unittest.TestCase):
398+
"""Tests for setup_decode_state_from_nnx()."""
399+
400+
def setUp(self):
401+
self.config = _make_config()
402+
self.mesh = _make_mesh(self.config)
403+
self.rng = jax.random.PRNGKey(0)
404+
405+
def test_returns_linen_train_state_and_annotations(self):
406+
"""Should return a linen TrainState whose params mirror the NNX model's nnx.Param values."""
407+
# Build a real (small) NNX model WITHOUT any patch active so from_pretrained
408+
# runs normally and produces concrete jax.Array weights.
409+
real_nnx_model = model_creation_utils.from_pretrained(self.config, mesh=self.mesh)
410+
411+
linen_model = model_creation_utils.from_config(self.config, mesh=self.mesh, rngs=None)
412+
413+
# Now patch from_pretrained so setup_decode_state_from_nnx never touches a checkpoint.
414+
with patch("maxtext.utils.model_creation_utils.from_pretrained", return_value=real_nnx_model) as mock_fp:
415+
state, state_mesh_annotations = model_creation_utils.setup_decode_state_from_nnx(
416+
linen_model, self.config, self.rng, self.mesh
417+
)
418+
419+
# from_pretrained must have been called with the right model_mode.
420+
mock_fp.assert_called_once()
421+
_, call_kwargs = mock_fp.call_args
422+
self.assertEqual(call_kwargs.get("model_mode"), MODEL_MODE_AUTOREGRESSIVE)
423+
424+
# The result should be a linen TrainState.
425+
self.assertIsInstance(state, train_state.TrainState)
426+
427+
# Params must be nested under "params" and be non-empty concrete arrays.
428+
self.assertIn("params", state.params)
429+
param_leaves = jax.tree.leaves(state.params["params"])
430+
self.assertGreater(len(param_leaves), 0)
431+
for leaf in param_leaves:
432+
self.assertIsInstance(leaf, jax.Array)
433+
434+
# The NNX Param values and the extracted linen params must be numerically identical.
435+
nnx_param_state = nnx.state(real_nnx_model, nnx.Param)
436+
nnx_leaves = jax.tree.leaves(nnx_param_state)
437+
linen_leaves = jax.tree.leaves(state.params["params"])
438+
self.assertEqual(len(nnx_leaves), len(linen_leaves))
439+
for nnx_val, linen_val in zip(nnx_leaves, linen_leaves):
440+
self.assertTrue(jnp.all(nnx_val == linen_val))
441+
442+
# state_mesh_annotations must be returned (non-None).
443+
self.assertIsNotNone(state_mesh_annotations)
444+
445+
396446
if __name__ == "__main__":
397447
unittest.main()

tests/utils/forward_pass_logit_checker.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from maxtext.models import models
5454
from maxtext.utils import max_logging
5555
from maxtext.utils import maxtext_utils
56+
from maxtext.utils import model_creation_utils
5657
import numpy as np
5758
import torch
5859
import torch.nn.functional as F
@@ -447,7 +448,10 @@ def main(config, test_args): # pylint: disable=W0621
447448
else:
448449
maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
449450
init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, rng1)
450-
maxtext_state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn)
451+
if test_args.ckpt_type == "linen":
452+
maxtext_state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn)
453+
else:
454+
maxtext_state, _ = model_creation_utils.setup_decode_state_from_nnx(maxtext_model, config, rng1, mesh)
451455

452456
prompts = ["I love to", "Today is a", "What is the"]
453457
all_data_to_save = []
@@ -554,6 +558,14 @@ def main(config, test_args): # pylint: disable=W0621
554558
default=False,
555559
help="Skip the first token during comparison to ignore BOS/init mismatches.",
556560
)
561+
parser.add_argument(
562+
"--ckpt_type",
563+
type=str,
564+
required=False,
565+
default="linen",
566+
choices=["linen", "nnx"],
567+
help="Checkpoint format to load: 'linen' (default) or 'nnx'.",
568+
)
557569

558570
# Parse known args returns the namespace AND the list of remaining arguments
559571
test_args, remaining_args = parser.parse_known_args()

0 commit comments

Comments
 (0)