3131import numpy as np
3232from jax .sharding import Mesh
3333from maxtext .configs import pyconfig
34- from maxtext .common .common_types import MODEL_MODE_TRAIN
34+ from maxtext .common .common_types import MODEL_MODE_AUTOREGRESSIVE , MODEL_MODE_TRAIN
3535from maxtext .layers import quantizations
3636from maxtext .models import models
3737from maxtext .utils import max_logging
@@ -580,20 +580,21 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx):
580580 }
581581 }
582582 else :
583- # structure of nnx checkpoint: {'decoder': {'value': ...}}
583+ # NNX checkpoint: {'decoder': {'value': ...}}, or NNX-RL with extra 'base' nesting.
584+ # Restore only nnx.Param — RNG variable shapes may differ between checkpoint and model.
584585 target_for_restore = jax .tree .map (
585586 lambda v : {"value" : v .value },
586587 sharded_state ,
587588 is_leaf = lambda n : isinstance (n , nnx .Variable ),
588589 )
589- target_for_restore = _adjust_target_for_moe_fusion (target_for_restore , metadata .item_metadata .tree , True )
590- item_to_restore = target_for_restore
591- base_restore_args = ocp .checkpoint_utils .construct_restore_args (target_for_restore )
590+ has_base_key = "base" in metadata .item_metadata .tree
591+ meta_tree_for_params = metadata .item_metadata .tree .get ("base" , metadata .item_metadata .tree )
592+ target_for_restore = _adjust_target_for_moe_fusion (target_for_restore , meta_tree_for_params , True )
593+ item_to_restore = {"base" : target_for_restore } if has_base_key else target_for_restore
592594 restore_args = _fix_restore_args_for_shape_mismatch (
593- base_restore_args ,
594- metadata .item_metadata .tree ,
595- mesh ,
595+ ocp .checkpoint_utils .construct_restore_args (target_for_restore ), meta_tree_for_params , mesh
596596 )
597+ restore_args = {"base" : restore_args } if has_base_key else restore_args
597598
598599 restored = ckptr .restore (
599600 epath .Path (config .load_parameters_path ),
@@ -603,9 +604,10 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx):
603604 )
604605
605606 if is_nnx_checkpoint :
607+ restored_root = restored ["base" ] if has_base_key else restored
606608 checkpoint = jax .tree .map (
607609 lambda v : v ["value" ],
608- restored ,
610+ restored_root ,
609611 is_leaf = lambda x : isinstance (x , dict ) and "value" in x and not isinstance (x .get ("value" ), dict ),
610612 )
611613 else :
@@ -656,6 +658,13 @@ def _fuse_moe_weights(ckpt_tree, model_arrays_tree):
656658 # This prevents the replicated intermediate copies from persisting until function return.
657659 del restored
658660
661+ def _filter_to_model_keys (ckpt , model ):
662+ """Recursively keep only keys present in model, dropping checkpoint-only fields (e.g. to_nnx__rngs)."""
663+ if not hasattr (ckpt , "items" ) or not hasattr (model , "items" ):
664+ return ckpt
665+ return {k : _filter_to_model_keys (ckpt [k ], model [k ]) for k in model if k in ckpt }
666+
667+ checkpoint = _filter_to_model_keys (checkpoint , model_arrays )
659668 checkpoint = jax .tree .map (_expand_checkpoint_to_model_shapes , checkpoint , model_arrays )
660669 nnx .update (model , checkpoint )
661670
@@ -672,3 +681,44 @@ def _fuse_moe_weights(ckpt_tree, model_arrays_tree):
672681 return model
673682 else :
674683 return model , mesh
684+
685+
def setup_decode_state_from_nnx(model, config, rng, mesh):
  """Setup decode state by loading an NNX or NNX-RL checkpoint into a linen TrainState.

  Calls from_pretrained (which handles NNX and NNX-RL 'base'-nested checkpoints and
  applies mesh sharding internally), then extracts nnx.Param values into a plain dict
  for the linen TrainState. For linen checkpoints, use maxtext_utils.setup_decode_state instead.

  Args:
    model: the flax linen model to initialize
    config: config object
    rng: jax.prng key
    mesh: jax.devices() mesh

  Returns:
    state: linen TrainState with params loaded from the NNX checkpoint
    state_mesh_annotations: the mesh annotations for the state
  """
  init_state_fn = partial(maxtext_utils.init_initial_state, model, None, config, False, rng)
  _, state_mesh_annotations, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, False)

  # from_pretrained applies mesh sharding to the restored arrays internally
  # (via jax.jit out_shardings), so no explicit resharding is needed here.
  pretrained_model = from_pretrained(config, mesh=mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE)

  def _to_plain_dict(node):
    # Unwrap nnx.Variable leaves and turn State mappings into plain nested dicts.
    if isinstance(node, nnx.Variable):
      return node.value
    if hasattr(node, "items") and not isinstance(node, jax.Array):
      return {key: _to_plain_dict(child) for key, child in node.items()}
    return node

  param_state = nnx.state(pretrained_model, nnx.Param)
  params = {"params": _to_plain_dict(param_state)}
  del pretrained_model, param_state  # free memory before building the TrainState

  state = maxtext_utils.init_decode_state(model.apply, params)
  return state, state_mesh_annotations
0 commit comments