@@ -888,6 +888,51 @@ def save_adapter_files(output_dir, weights, config, found_modules, model_id):
888888 json .dump (adapter_config , f , indent = 4 )
889889
890890
891+ def param_key_parts_from_path (path_tuple ) -> list [str ]:
892+ """Convert a JAX tree path into MaxText dash-joined key segments.
893+
894+ Normalizes two NNX storage artifacts so the result follows the MaxText-Linen
895+ naming convention and matches the param-mapping tables (e.g.
896+ ``params-decoder-layers_0-self_attention-query-kernel``):
897+
898+ * ``nnx.List`` layer stacks flatten to an *integer* path key
899+ (``decoder -> layers -> 0 -> ...``), which Orbax may restore as a numeric
900+ *string* (``"0"``). Either form is folded into the preceding segment as
901+ ``<name>_<idx>`` (``layers_0``), matching Linen's ``layers_0`` name. (A pure
902+ integer key would otherwise raise ``TypeError: sequence item N: expected str
903+ instance, int found`` when joined; a string ``"0"`` would mismatch the
904+ ``layers_0`` mapping.)
905+ * ``nnx.Variable`` leaves flatten with a trailing ``value`` key
906+ (``...-kernel -> value``). That wrapper segment is dropped, since MaxText-Linen
907+ param keys have no such suffix.
908+
909+ Scanned / plain Linen string paths (no integer key, no trailing ``value``) are
910+ returned unchanged.
911+
912+ Args:
913+ path_tuple: A path produced by ``jax.tree_util.tree_flatten_with_path`` or
914+ ``tree_leaves_with_path`` (a sequence of ``DictKey`` / ``SequenceKey`` /
915+ ``GetAttrKey`` / ``FlattenedIndexKey`` entries).
916+
917+ Returns:
918+ The list of string key segments, e.g. ``["decoder", "layers_0", "kernel"]``.
919+ """
920+ parts : list [str ] = []
921+ for entry in path_tuple :
922+ key = getattr (entry , "key" , getattr (entry , "idx" , getattr (entry , "name" , entry )))
923+ # Fold a layer/expert index (an int, or a numeric string after an Orbax
924+ # round-trip) into the preceding segment: ["layers", 0] -> "layers_0".
925+ if (isinstance (key , int ) or (isinstance (key , str ) and key .isdigit ())) and parts :
926+ parts [- 1 ] = f"{ parts [- 1 ]} _{ key } "
927+ else :
928+ parts .append (str (key ))
929+ # Drop the trailing ``value`` segment that NNX adds for each ``nnx.Variable``
930+ # leaf (``...-kernel -> value``); MaxText-Linen param keys have no such wrapper.
931+ if parts and parts [- 1 ] == "value" :
932+ parts .pop ()
933+ return parts
934+
935+
891936def extract_nnx_weights (weights_dict : dict ) -> dict [str , np .ndarray ]:
892937 """Extract weights from NNX checkpoint structure.
893938
@@ -903,13 +948,10 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
903948 result = {}
904949 leaves_with_paths = jax .tree_util .tree_leaves_with_path (weights_dict )
905950 for path_tuple , leaf_value in leaves_with_paths :
906- path_keys = [ k . key for k in path_tuple ]
951+ path_keys = param_key_parts_from_path ( path_tuple )
907952 # Skip NNX RNG state variables (not model weights)
908953 if "to_nnx__rngs" in path_keys or any (k .endswith ("_rngs" ) for k in path_keys ):
909954 continue
910- # Skip if this is the "value" key itself - we want the parent path
911- if path_keys [- 1 ] == "value" :
912- path_keys = path_keys [:- 1 ]
913955 maxtext_param_key = "params-" + "-" .join (path_keys )
914956 if not isinstance (leaf_value , (jax .Array , np .ndarray )):
915957 raise ValueError (f"Leaf value for { maxtext_param_key } is not an array. Type: { type (leaf_value )} ." )
@@ -932,8 +974,7 @@ def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]:
932974 result = {}
933975 leaves_with_paths = jax .tree_util .tree_leaves_with_path (weights_dict )
934976 for path_tuple , leaf_value in leaves_with_paths :
935- path_keys = [k .key for k in path_tuple ]
936- # Construct maxtext_param_key from path_tuple
977+ path_keys = param_key_parts_from_path (path_tuple )
937978 maxtext_param_key = "params-" + "-" .join (path_keys )
938979 if not isinstance (leaf_value , (jax .Array , np .ndarray )):
939980 raise ValueError (f"Leaf value for { maxtext_param_key } is not an array. Type: { type (leaf_value )} ." )
0 commit comments