# pylint: disable=bare-except, consider-using-generator
"""Utils that are only relevant for creating a model in MaxText."""

+ import dataclasses
from collections.abc import Sequence
from functools import partial
from typing import overload

from flax import nnx
import flax.linen as nn
import jax
+ import jax.numpy as jnp
from jax.sharding import AxisType, Mesh
from maxtext.configs import pyconfig
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
from maxtext.layers import quantizations
from maxtext.models import models
+ from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from orbax import checkpoint as ocp

+ # ArrayMetadata is not importable from every Orbax version; fall back to duck
+ # typing when the import path is unavailable.
+ try:
+   from orbax.checkpoint.metadata import ArrayMetadata as _OrbaxArrayMetadata
+
+   def _is_orbax_array_metadata(x):
+     return isinstance(x, _OrbaxArrayMetadata)
+
+ except ImportError:
+
+   def _is_orbax_array_metadata(x):
+     return hasattr(x, "shape") and hasattr(x, "sharding") and hasattr(x, "dtype") and not isinstance(x, jax.Array)
+
+ def _expand_checkpoint_to_model_shapes(ckpt_arr, model_arr):
+   """Expand ckpt_arr to model_arr's shape and re-shard to model_arr's sharding.
+
+   Used to expand checkpoint KV-head (and similar) arrays that were saved with
+   fewer heads than the padded model shape requires (e.g. due to TP/EP padding
+   in adapter.py). Each checkpoint dimension must evenly divide the
+   corresponding model dimension.
+
+   Uses jnp.repeat so that each original slice is placed adjacent to its copies.
+   For GQA with TP, device i needs KV head i // ratio from the original checkpoint,
+   so the correct layout is e.g. [h0, h0, h1, h1, h2, h2, h3, h3] rather than
+   [h0, h1, h2, h3, h0, h1, h2, h3].
+   """
+   ckpt_shape = ckpt_arr.shape
+   model_shape = model_arr.shape
+   if ckpt_shape == model_shape:
+     return jax.device_put(ckpt_arr, model_arr.sharding)
+   if len(ckpt_shape) != len(model_shape):
+     raise ValueError(f"Checkpoint and model arrays have different ranks: {ckpt_shape} vs {model_shape}")
+   result = ckpt_arr
+   for axis, (ckpt_dim, model_dim) in enumerate(zip(ckpt_shape, model_shape)):
+     if model_dim % ckpt_dim != 0:
+       raise ValueError(
+           f"Model dimension {model_dim} is not evenly divisible by checkpoint dimension {ckpt_dim}."
+           f" Checkpoint shape: {ckpt_shape}, model shape: {model_shape}"
+       )
+     if model_dim != ckpt_dim:
+       result = jnp.repeat(result, model_dim // ckpt_dim, axis=axis)
+   return jax.device_put(result, model_arr.sharding)
+
+
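To make the layout argument in the docstring concrete, here is a standalone sketch (not part of the change; the head array is an illustrative stand-in): with an expansion ratio of 2, jnp.repeat places head i // 2 at position i, which is what each tensor-parallel shard expects, whereas jnp.tile would cycle through the heads instead.

import jax.numpy as jnp

heads = jnp.arange(4)        # stand-ins for KV heads h0..h3
print(jnp.repeat(heads, 2))  # [0 0 1 1 2 2 3 3] -> shard i reads head i // 2
print(jnp.tile(heads, 2))    # [0 1 2 3 0 1 2 3] -> wrong head on most shards
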
+ def _fix_restore_args_for_shape_mismatch(restore_args, stored_metadata_tree, mesh):
+   """Use replicated sharding for arrays whose checkpoint shape differs from the model shape.
+
+   When the model is initialized with padded shapes (e.g. KV heads padded to match
+   TP size) but the checkpoint was saved with smaller shapes, Orbax will reject the
+   restore because the provided sharding is incompatible with the stored shape.
+   For those arrays we switch to a fully replicated sharding and clear global_shape
+   so Orbax loads the array as written. _expand_checkpoint_to_model_shapes then
+   expands and re-shards the loaded arrays to match the model.
+
+   Uses tree_map_with_path so each ArrayRestoreArgs is looked up by path in the
+   metadata dict, which avoids the ordering/count mismatches that come from
+   independently flattening two trees with different pytree node types
+   (e.g. nnx.State vs plain dict).
+   """
+   replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+
+   def _key_str(key):
+     """Extract the string name from a JAX path key (DictKey, GetAttrKey, etc.)."""
+     if hasattr(key, "key"):
+       return str(key.key)
+     if hasattr(key, "name"):  # GetAttrKey stores its attribute name in `name`
+       return str(key.name)
+     if hasattr(key, "attr"):
+       return str(key.attr)
+     return str(key)
+
+   def _lookup_stored_meta(path):
+     """Navigate stored_metadata_tree using path keys from the restore_args tree."""
+     node = stored_metadata_tree
+     for key in path:
+       name = _key_str(key)
+       if isinstance(node, dict) and name in node:
+         node = node[name]
+       else:
+         return None
+     return node
+
+   mismatched_paths = []
+
+   def _fix_one(path, restore_arg):
+     if not isinstance(restore_arg, ocp.ArrayRestoreArgs):
+       return restore_arg
+     stored_meta = _lookup_stored_meta(path)
+     if stored_meta is not None and _is_orbax_array_metadata(stored_meta):
+       stored_shape = tuple(stored_meta.shape)
+       if restore_arg.global_shape is not None and restore_arg.global_shape != stored_shape:
+         mismatched_paths.append(
+             f"  {'.'.join(_key_str(k) for k in path)}: stored={stored_shape} -> model={restore_arg.global_shape}"
+         )
+         return dataclasses.replace(
+             restore_arg, global_shape=None, shape=None, sharding=replicated, mesh=None, mesh_axes=None
+         )
+     return restore_arg
+
+   fixed = jax.tree_util.tree_map_with_path(_fix_one, restore_args, is_leaf=lambda x: isinstance(x, ocp.ArrayRestoreArgs))
+   if mismatched_paths:
+     max_logging.log(
+         f"Checkpoint shape mismatches ({len(mismatched_paths)} arrays): loading with replicated "
+         "sharding and expanding to model shape after restore.\n" + "\n".join(mismatched_paths)
+     )
+   return fixed
+
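A minimal sketch of the two mechanisms this function combines (everything below, including FakeRestoreArgs, is an illustrative stand-in rather than a MaxText or Orbax API): tree_map_with_path hands each leaf its path, the DictKey names drive the lookup into the stored-metadata dict, and dataclasses.replace rewrites only the entries whose shapes disagree.

import dataclasses
import jax


@dataclasses.dataclass
class FakeRestoreArgs:
  """Illustrative stand-in for ocp.ArrayRestoreArgs."""

  global_shape: tuple = None
  sharding: object = None


restore_args = {"decoder": {"kv_proj": FakeRestoreArgs(global_shape=(8, 128))}}
stored_meta = {"decoder": {"kv_proj": {"shape": (4, 128)}}}  # checkpoint-side shapes


def fix_one(path, arg):
  node = stored_meta
  for key in path:  # each entry is a DictKey; .key holds the dict name
    node = node[key.key]
  if tuple(node["shape"]) != arg.global_shape:
    return dataclasses.replace(arg, global_shape=None)  # load as written
  return arg


fixed = jax.tree_util.tree_map_with_path(
    fix_one, restore_args, is_leaf=lambda x: isinstance(x, FakeRestoreArgs)
)
print(fixed)  # {'decoder': {'kv_proj': FakeRestoreArgs(global_shape=None, sharding=None)}}
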

@overload
def from_config(
@@ -154,6 +260,7 @@ def create_sharded_state():
  with nn.logical_axis_rules(config.logical_axis_rules):
    sharded_state = create_sharded_state()
  model = nnx.merge(graphdef, sharded_state)
+
  # Print weight sharding info under debug sharding mode.
  if config.debug_sharding:
    max_utils.print_non_trivial_mesh_axis(model.mesh)
@@ -163,6 +270,7 @@ def create_sharded_state():
      mesh=model.mesh,
      logical_annotations=specs,
  )
+
  if config.load_parameters_path:
    try:
      ckptr = ocp.Checkpointer(
@@ -196,7 +304,16 @@ def create_sharded_state():
        )

        item_to_restore = {"params": {"params": target_for_restore}}
-         restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
+         base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
+         restore_args = {
+             "params": {
+                 "params": _fix_restore_args_for_shape_mismatch(
+                     base_restore_args,
+                     metadata.item_metadata.tree["params"]["params"],
+                     mesh,
+                 )
+             }
+         }
      else:
        # structure of the nnx checkpoint: {'decoder': {'value': ...}}
        target_for_restore = jax.tree.map(
@@ -205,7 +322,12 @@ def create_sharded_state():
            is_leaf=lambda n: isinstance(n, nnx.Variable),
        )
        item_to_restore = target_for_restore
-         restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
+         base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
+         restore_args = _fix_restore_args_for_shape_mismatch(
+             base_restore_args,
+             metadata.item_metadata.tree,
+             mesh,
+         )

      restored = ckptr.restore(
          epath.Path(config.load_parameters_path),
@@ -224,6 +346,12 @@ def create_sharded_state():
        checkpoint = restored["params"]["params"]

      if checkpoint:
+         model_arrays = jax.tree.map(
+             lambda v: v.value,
+             sharded_state,
+             is_leaf=lambda n: isinstance(n, nnx.Variable),
+         )
+         checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays)
        nnx.update(model, checkpoint)

    except Exception as e: