Skip to content

Commit e92ca84

Browse files
Merge pull request #4132 from CIeNET-International:fix/update-nnx-decoder
PiperOrigin-RevId: 932494136
2 parents 63abe03 + ccaa5c8 commit e92ca84

10 files changed

Lines changed: 905 additions & 207 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
from maxtext.common.common_types import MODEL_MODE_TRAIN
6868
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
6969
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
70-
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
70+
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
7171
from maxtext.inference.inference_utils import str2bool
7272
from maxtext.layers import quantizations
7373
from maxtext.models import models
@@ -333,8 +333,7 @@ def get_maxtext_model_info(config):
333333
# preprocess state
334334
maxtext_abstract_dict = {}
335335
for mt_target_idx, (path_tuple, abstract_leaf_value) in enumerate(abstract_params_flat):
336-
key_parts = [k.key for k in path_tuple if hasattr(k, "key")]
337-
mt_param_key = "params-" + "-".join(key_parts)
336+
mt_param_key = "params-" + "-".join(param_key_parts_from_path(path_tuple))
338337
mt_target_shape = abstract_leaf_value.shape
339338
maxtext_abstract_dict[mt_param_key] = (mt_target_idx, mt_target_shape)
340339

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,51 @@ def save_adapter_files(output_dir, weights, config, found_modules, model_id):
888888
json.dump(adapter_config, f, indent=4)
889889

890890

891+
def param_key_parts_from_path(path_tuple) -> list[str]:
892+
"""Convert a JAX tree path into MaxText dash-joined key segments.
893+
894+
Normalizes two NNX storage artifacts so the result follows the MaxText-Linen
895+
naming convention and matches the param-mapping tables (e.g.
896+
``params-decoder-layers_0-self_attention-query-kernel``):
897+
898+
* ``nnx.List`` layer stacks flatten to an *integer* path key
899+
(``decoder -> layers -> 0 -> ...``), which Orbax may restore as a numeric
900+
*string* (``"0"``). Either form is folded into the preceding segment as
901+
``<name>_<idx>`` (``layers_0``), matching Linen's ``layers_0`` name. (A pure
902+
integer key would otherwise raise ``TypeError: sequence item N: expected str
903+
instance, int found`` when joined; a string ``"0"`` would mismatch the
904+
``layers_0`` mapping.)
905+
* ``nnx.Variable`` leaves flatten with a trailing ``value`` key
906+
(``...-kernel -> value``). That wrapper segment is dropped, since MaxText-Linen
907+
param keys have no such suffix.
908+
909+
Scanned / plain Linen string paths (no integer key, no trailing ``value``) are
910+
returned unchanged.
911+
912+
Args:
913+
path_tuple: A path produced by ``jax.tree_util.tree_flatten_with_path`` or
914+
``tree_leaves_with_path`` (a sequence of ``DictKey`` / ``SequenceKey`` /
915+
``GetAttrKey`` / ``FlattenedIndexKey`` entries).
916+
917+
Returns:
918+
The list of string key segments, e.g. ``["decoder", "layers_0", "kernel"]``.
919+
"""
920+
parts: list[str] = []
921+
for entry in path_tuple:
922+
key = getattr(entry, "key", getattr(entry, "idx", getattr(entry, "name", entry)))
923+
# Fold a layer/expert index (an int, or a numeric string after an Orbax
924+
# round-trip) into the preceding segment: ["layers", 0] -> "layers_0".
925+
if (isinstance(key, int) or (isinstance(key, str) and key.isdigit())) and parts:
926+
parts[-1] = f"{parts[-1]}_{key}"
927+
else:
928+
parts.append(str(key))
929+
# Drop the trailing ``value`` segment that NNX adds for each ``nnx.Variable``
930+
# leaf (``...-kernel -> value``); MaxText-Linen param keys have no such wrapper.
931+
if parts and parts[-1] == "value":
932+
parts.pop()
933+
return parts
934+
935+
891936
def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
892937
"""Extract weights from NNX checkpoint structure.
893938
@@ -903,13 +948,10 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
903948
result = {}
904949
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
905950
for path_tuple, leaf_value in leaves_with_paths:
906-
path_keys = [k.key for k in path_tuple]
951+
path_keys = param_key_parts_from_path(path_tuple)
907952
# Skip NNX RNG state variables (not model weights)
908953
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
909954
continue
910-
# Skip if this is the "value" key itself - we want the parent path
911-
if path_keys[-1] == "value":
912-
path_keys = path_keys[:-1]
913955
maxtext_param_key = "params-" + "-".join(path_keys)
914956
if not isinstance(leaf_value, (jax.Array, np.ndarray)):
915957
raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
@@ -932,8 +974,7 @@ def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]:
932974
result = {}
933975
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
934976
for path_tuple, leaf_value in leaves_with_paths:
935-
path_keys = [k.key for k in path_tuple]
936-
# Construct maxtext_param_key from path_tuple
977+
path_keys = param_key_parts_from_path(path_tuple)
937978
maxtext_param_key = "params-" + "-".join(path_keys)
938979
if not isinstance(leaf_value, (jax.Array, np.ndarray)):
939980
raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")

src/maxtext/layers/attentions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def __init__(
472472
if self.config.attention_sink:
473473
self.sinks = nnx.Param(
474474
default_bias_init(self.rngs.params(), (self.config.num_query_heads,), self.weight_dtype),
475-
sharding=(None,),
475+
out_sharding=(None,),
476476
)
477477
else:
478478
self.sinks = None
@@ -517,14 +517,14 @@ def __init__(
517517
elif self.is_qwen3_hybrid:
518518
self.query_norm = Qwen3NextRMSNorm(
519519
num_features=self.config.head_dim,
520-
eps=self.config.normalization_layer_epsilon,
520+
epsilon=self.config.normalization_layer_epsilon,
521521
dtype=self.config.dtype,
522522
weight_dtype=self.config.weight_dtype,
523523
rngs=self.rngs,
524524
)
525525
self.key_norm = Qwen3NextRMSNorm(
526526
num_features=self.config.head_dim,
527-
eps=self.config.normalization_layer_epsilon,
527+
epsilon=self.config.normalization_layer_epsilon,
528528
dtype=self.config.dtype,
529529
weight_dtype=self.config.weight_dtype,
530530
rngs=self.rngs,

0 commit comments

Comments
 (0)