1515# pylint: disable=bare-except, consider-using-generator
1616""" Utils that are only interesting for creating a model in MaxText. """
1717
18+ import dataclasses
1819from collections .abc import Sequence
1920from functools import partial
2021from typing import overload
2324from flax import nnx
2425import flax .linen as nn
2526import jax
27+ import jax .numpy as jnp
2628from jax .sharding import AxisType , Mesh
2729from maxtext .configs import pyconfig
2830from maxtext .common .common_types import MODEL_MODE_TRAIN , ShardMode
2931from maxtext .layers import quantizations
3032from maxtext .models import models
33+ from maxtext .utils import max_logging
3134from maxtext .utils import max_utils
3235from maxtext .utils import maxtext_utils
3336from orbax import checkpoint as ocp
3437
38+ try :
39+ from orbax .checkpoint .metadata import ArrayMetadata as _OrbaxArrayMetadata
40+
41+ def _is_orbax_array_metadata (x ):
42+ return isinstance (x , _OrbaxArrayMetadata )
43+
44+ except ImportError :
45+
46+ def _is_orbax_array_metadata (x ):
47+ return hasattr (x , "shape" ) and hasattr (x , "sharding" ) and hasattr (x , "dtype" ) and not isinstance (x , jax .Array )
48+
49+
50+ def _expand_checkpoint_to_model_shapes (ckpt_arr , model_arr ):
51+ """Expand ckpt_arr to model_arr's shape and re-shard to model_arr's sharding.
52+
53+ Used to expand checkpoint KV-head (and similar) arrays that were saved with
54+ fewer heads than the padded model shape requires (e.g. due to TP/EP padding
55+ in adapter.py). Each dimension must divide evenly into the corresponding
56+ model dimension.
57+
58+ Uses jnp.repeat so that each original slice is placed adjacent to its copies.
59+ For GQA with TP, device i needs KV head i//ratio from the original checkpoint,
60+ so the correct layout is e.g. [h0, h0, h1, h1, h2, h2, h3, h3] rather than
61+ [h0, h1, h2, h3, h0, h1, h2, h3].
62+ """
63+ ckpt_shape = ckpt_arr .shape
64+ model_shape = model_arr .shape
65+ if ckpt_shape == model_shape :
66+ return jax .device_put (ckpt_arr , model_arr .sharding )
67+ if len (ckpt_shape ) != len (model_shape ):
68+ raise ValueError (
69+ f"Checkpoint and model arrays have different ranks: { ckpt_shape } vs { model_shape } . "
70+ "If the checkpoint was saved with scan_layers=True (stacked layers), convert it to "
71+ "unscanned format before loading with vLLM (vllm.yml sets scan_layers=False)."
72+ )
73+ result = ckpt_arr
74+ for axis , (ckpt_dim , model_dim ) in enumerate (zip (ckpt_shape , model_shape )):
75+ if model_dim % ckpt_dim != 0 :
76+ raise ValueError (
77+ f"Model dimension { model_dim } is not evenly divisible by checkpoint dimension { ckpt_dim } ."
78+ f" Full shapes — checkpoint: { ckpt_shape } , model: { model_shape } "
79+ )
80+ if model_dim != ckpt_dim :
81+ result = jnp .repeat (result , model_dim // ckpt_dim , axis = axis )
82+ return jax .device_put (result , model_arr .sharding )
83+
84+
def _fix_restore_args_for_shape_mismatch(restore_args, stored_metadata_tree, mesh):
  """Relax restore args for arrays whose stored shape differs from the model's.

  A model initialized with padded shapes (e.g. KV heads padded to the TP size)
  cannot directly restore a checkpoint saved with smaller shapes: Orbax rejects
  the restore because the provided sharding is incompatible with the stored
  shape. For each such array we drop the requested shapes and substitute a
  fully-replicated sharding so Orbax loads the array exactly as written;
  _expand_checkpoint_to_model_shapes afterwards grows and re-shards the loaded
  array to match the model.

  Leaves are matched to their stored metadata by path (tree_map_with_path plus
  a dict walk) instead of flattening both trees independently, which would
  break on ordering/count mismatches when the trees use different pytree node
  types (e.g. nnx.State vs plain dict).

  Args:
    restore_args: Pytree of ocp.ArrayRestoreArgs built from the model state.
    stored_metadata_tree: Nested dict of per-array checkpoint metadata.
    mesh: Device mesh used to build the replicated sharding.

  Returns:
    restore_args with each mismatched leaf switched to a shape-free,
    replicated-sharding ArrayRestoreArgs; all other leaves untouched.
  """
  fully_replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
  mismatch_reports = []

  def _name_of(key):
    """String name of a JAX path key (DictKey, GetAttrKey, etc.)."""
    for field in ("key", "attr"):
      if hasattr(key, field):
        return str(getattr(key, field))
    return str(key)

  def _stored_meta_at(path):
    """Walk stored_metadata_tree by the string names along `path`; None if absent."""
    node = stored_metadata_tree
    for key in path:
      name = _name_of(key)
      if not isinstance(node, dict) or name not in node:
        return None
      node = node[name]
    return node

  def _rewrite_leaf(path, restore_arg):
    # Only ArrayRestoreArgs leaves with known, rank-matching, differing shapes
    # are rewritten; everything else passes through unchanged.
    if not isinstance(restore_arg, ocp.ArrayRestoreArgs):
      return restore_arg
    stored_meta = _stored_meta_at(path)
    if stored_meta is None or not _is_orbax_array_metadata(stored_meta):
      return restore_arg
    stored_shape = tuple(stored_meta.shape)
    model_shape = restore_arg.global_shape
    if model_shape is None or model_shape == stored_shape or len(stored_shape) != len(model_shape):
      return restore_arg
    dotted_path = ".".join(_name_of(k) for k in path)
    mismatch_reports.append(f"  {dotted_path}: stored={stored_shape} -> model={model_shape}")
    # Clear every shape/sharding hint so Orbax loads the array as-written.
    return dataclasses.replace(
        restore_arg, global_shape=None, shape=None, sharding=fully_replicated, mesh=None, mesh_axes=None
    )

  fixed = jax.tree_util.tree_map_with_path(
      _rewrite_leaf, restore_args, is_leaf=lambda x: isinstance(x, ocp.ArrayRestoreArgs)
  )
  if mismatch_reports:
    max_logging.log(
        f"Checkpoint shape mismatches ({len(mismatch_reports)} arrays): loading with replicated "
        "sharding and expanding to model shape after restore.\n" + "\n".join(mismatch_reports)
    )
  return fixed
148+
35149
36150@overload
37151def from_config (
@@ -154,6 +268,7 @@ def create_sharded_state():
154268 with nn .logical_axis_rules (config .logical_axis_rules ):
155269 sharded_state = create_sharded_state ()
156270 model = nnx .merge (graphdef , sharded_state )
271+
157272 # print weights sharding info under debug sharding mode
158273 if config .debug_sharding :
159274 max_utils .print_non_trivial_mesh_axis (model .mesh )
@@ -163,6 +278,7 @@ def create_sharded_state():
163278 mesh = model .mesh ,
164279 logical_annotations = specs ,
165280 )
281+
166282 if config .load_parameters_path :
167283 try :
168284 ckptr = ocp .Checkpointer (
@@ -196,7 +312,16 @@ def create_sharded_state():
196312 )
197313
198314 item_to_restore = {"params" : {"params" : target_for_restore }}
199- restore_args = {"params" : {"params" : ocp .checkpoint_utils .construct_restore_args (target_for_restore )}}
315+ base_restore_args = ocp .checkpoint_utils .construct_restore_args (target_for_restore )
316+ restore_args = {
317+ "params" : {
318+ "params" : _fix_restore_args_for_shape_mismatch (
319+ base_restore_args ,
320+ metadata .item_metadata .tree ["params" ]["params" ],
321+ mesh ,
322+ )
323+ }
324+ }
200325 else :
201326 # structure of nnx checkpoint: {'decoder': {'value': ...}}
202327 target_for_restore = jax .tree .map (
@@ -205,7 +330,12 @@ def create_sharded_state():
205330 is_leaf = lambda n : isinstance (n , nnx .Variable ),
206331 )
207332 item_to_restore = target_for_restore
208- restore_args = ocp .checkpoint_utils .construct_restore_args (target_for_restore )
333+ base_restore_args = ocp .checkpoint_utils .construct_restore_args (target_for_restore )
334+ restore_args = _fix_restore_args_for_shape_mismatch (
335+ base_restore_args ,
336+ metadata .item_metadata .tree ,
337+ mesh ,
338+ )
209339
210340 restored = ckptr .restore (
211341 epath .Path (config .load_parameters_path ),
@@ -223,7 +353,22 @@ def create_sharded_state():
223353 else :
224354 checkpoint = restored ["params" ]["params" ]
225355
356+ loaded_count = len (jax .tree_util .tree_leaves (checkpoint ))
357+ expected_count = len (jax .tree_util .tree_leaves (target_for_restore ))
358+ if loaded_count < expected_count :
359+ raise ValueError (
360+ f"Checkpoint at '{ config .load_parameters_path } ' loaded only { loaded_count } of { expected_count } "
361+ "expected parameter arrays. This usually means a scanned (stacked-layers) checkpoint was provided "
362+ "where an unscanned checkpoint is required. Please convert the checkpoint to unscanned format first."
363+ )
364+
226365 if checkpoint :
366+ model_arrays = jax .tree .map (
367+ lambda v : v .value ,
368+ sharded_state ,
369+ is_leaf = lambda n : isinstance (n , nnx .Variable ),
370+ )
371+ checkpoint = jax .tree .map (_expand_checkpoint_to_model_shapes , checkpoint , model_arrays )
227372 nnx .update (model , checkpoint )
228373
229374 except Exception as e :
0 commit comments