Skip to content

Commit 4d99240

Browse files
Merge pull request #2946 from AI-Hypercomputer:hengtaoguo-conversion
PiperOrigin-RevId: 856905495
2 parents d6729b0 + 1fe8fd3 commit 4d99240

2 files changed

Lines changed: 153 additions & 30 deletions

File tree

src/MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,12 @@
5656
from typing import Sequence
5757
import time
5858
from tqdm import tqdm
59-
import numpy as np
6059

6160
from transformers import AutoTokenizer, AutoProcessor
6261

6362
from absl import app
6463

6564
from MaxText import max_utils
66-
from MaxText import maxengine
6765
from MaxText import pyconfig
6866
from MaxText import max_logging
6967
from MaxText.utils.ckpt_conversion.utils.param_mapping import (
@@ -76,6 +74,8 @@
7674
validate_and_filter_param_map_keys,
7775
process_maxtext_param,
7876
save_model_files,
77+
load_orbax_checkpoint,
78+
detect_and_extract_checkpoint,
7979
HF_IDS,
8080
)
8181

@@ -133,14 +133,10 @@ def main(argv: Sequence[str]) -> None:
133133
max_utils.print_system_information()
134134
overall_start = time.time()
135135

136-
# Load Maxtext checkpoint
137-
max_logging.log("\nLoading Orbax checkpoint...")
136+
# Load Maxtext checkpoint using Orbax to get full parameter dict
137+
max_logging.log(f"\nLoading Orbax checkpoint from: {config.load_parameters_path}")
138138
start = time.time()
139-
engine = maxengine.MaxEngine(config)
140-
rng = jax.random.PRNGKey(1234)
141-
rng, rng_load_params = jax.random.split(rng)
142-
# load params from maxengine
143-
loaded_params_from_engine = engine.load_params(rng_load_params)
139+
checkpoint_dict = load_orbax_checkpoint(config)
144140
max_logging.log(f"Elapse for checkpoint load: {(time.time() - start) / 60:.2f} min")
145141

146142
if not config.base_output_directory:
@@ -170,27 +166,10 @@ def main(argv: Sequence[str]) -> None:
170166
shape_map = mappings["shape_mapping"] # HF target shapes
171167
hook_fn_map = mappings["hook_fn_mapping"]
172168

173-
# 4. Transform Weights
174-
# MaxText `engine.load_params()` returns `state.params` (a FrozenDict).
175-
# The actual weights are typically under `state.params['params']`.
176-
actual_weights_dict = loaded_params_from_engine.get("params")
177-
if actual_weights_dict is None:
178-
raise ValueError("Loaded parameters from engine do not contain a 'params' key. Structure might be unexpected.")
179-
leaves_with_paths = jax.tree_util.tree_leaves_with_path(actual_weights_dict)
180-
181-
# Construct maxtext_state_dict: {parameter name: parameter weight}
182-
maxtext_state_dict = {}
183-
for path_tuple, leaf_value in leaves_with_paths:
184-
# Construct maxtext_param_key from path_tuple
185-
maxtext_param_key = "params-" + "-".join(k.key for k in path_tuple)
186-
# Check leaf value is an array
187-
if not isinstance(leaf_value, (jax.Array, np.ndarray)):
188-
raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
189-
maxtext_state_dict[maxtext_param_key] = leaf_value
190-
191-
# The param_map may contain tuples as keys, which represent N-to-1 mappings from maxtext to huggingface
192-
# Check maxtext_state_dict is a subset of flattened param_map
193-
# Skip extra keys from param_map
169+
# 4. Extract and transform weights for Linen/NNX-SFT/NNX-RL checkpoints
170+
maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict)
171+
172+
# Validate that checkpoint keys match the parameter mapping
194173
filtered_map_keys = validate_and_filter_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
195174

196175
# Iterate through the parameter map to transform and collect weights.

src/MaxText/utils/ckpt_conversion/utils/utils.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
from MaxText import max_logging
4545
import psutil
4646

47+
from etils import epath
48+
import orbax.checkpoint as ocp
49+
4750

4851
SAFE_TENSORS_CONFIG_FILE = "config.json"
4952
SAFE_TENSORS_WEIGHTS_FILE = "model.safetensors"
@@ -130,6 +133,7 @@ def validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys):
130133
"maxtext_state_dict must be a subset of flattened param_map"
131134
+ f"\nparam map\n{param_map_keys}"
132135
+ f"\nmaxtext:\n{maxtext_state_keys}"
136+
+ f"\nmissing keys:\n{missing_keys}"
133137
)
134138

135139
# 2 Filter: param map may have extra keys
@@ -749,6 +753,146 @@ def print_ram_usage(stage=""):
749753
)
750754

751755

756+
def load_orbax_checkpoint(config) -> dict:
  """Loads a full Orbax checkpoint from disk with unsharded arrays.

  Every array leaf is restored fully replicated — an empty ``PartitionSpec``
  over a flat mesh spanning all local devices — so downstream conversion code
  can treat the weights as ordinary host-addressable arrays regardless of the
  sharding used at save time.

  Args:
    config: MaxText config containing checkpoint storage settings; must
      provide `load_parameters_path` and the `checkpoint_storage_*` options.

  Returns:
    Dictionary containing the full checkpoint structure.
  """
  # Create Orbax checkpointer with the same storage options used at save time.
  ckptr = ocp.Checkpointer(
      ocp.PyTreeCheckpointHandler(
          restore_concurrent_gb=config.checkpoint_storage_concurrent_gb,
          use_ocdbt=config.checkpoint_storage_use_ocdbt,
          use_zarr3=config.checkpoint_storage_use_zarr3,
      )
  )

  # Get checkpoint metadata (tree structure + per-array shape info).
  checkpoint_path = epath.Path(config.load_parameters_path)
  metadata = ckptr.metadata(checkpoint_path)

  # Flat mesh over all devices; an empty PartitionSpec replicates each array.
  devices = np.array(jax.devices()).reshape((-1,))
  single_device_mesh = jax.sharding.Mesh(devices, ("x",))
  # Hoisted out of the per-leaf lambda: the same replicated sharding applies
  # to every array in the tree.
  replicated = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec())

  # Build restore args per leaf. `is_leaf` stops descent at anything with a
  # `.shape` (array metadata), so the mapped function only ever sees either an
  # array-metadata leaf or a non-array leaf — no recursion into dicts needed.
  restore_args = jax.tree_util.tree_map(
      lambda leaf: ocp.ArrayRestoreArgs(sharding=replicated) if hasattr(leaf, "shape") else None,
      metadata.item_metadata.tree,
      is_leaf=lambda leaf: hasattr(leaf, "shape"),
  )

  # Restore the entire checkpoint.
  return ckptr.restore(checkpoint_path, restore_args=restore_args)
799+
800+
801+
def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
  """Extract weights from NNX checkpoint structure.

  NNX checkpoints have structure: {'decoder': {'decoder_norm': {'scale': {'value': array}}}}
  This function flattens it to: {'params-decoder-decoder_norm-scale': array}

  Args:
    weights_dict: NNX checkpoint weights dictionary

  Returns:
    Dictionary mapping parameter names to weight arrays

  Raises:
    ValueError: if any non-RNG leaf of the tree is not a jax or numpy array.
  """
  result = {}
  leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
  for path_tuple, leaf_value in leaves_with_paths:
    path_keys = [k.key for k in path_tuple]
    # Skip NNX RNG state variables (not model weights). This also covers the
    # 'to_nnx__rngs' key, which itself ends with '_rngs'.
    if any(k.endswith("_rngs") for k in path_keys):
      continue
    # Skip if this is the "value" key itself - we want the parent path
    if path_keys[-1] == "value":
      path_keys = path_keys[:-1]
    maxtext_param_key = "params-" + "-".join(path_keys)
    if not isinstance(leaf_value, (jax.Array, np.ndarray)):
      raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
    result[maxtext_param_key] = leaf_value
  return result
828+
829+
830+
def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]:
  """Flatten a Linen checkpoint's nested weight tree into a flat dict.

  Linen checkpoints nest weights as
  {'params': {'decoder': {'decoder_norm': {'scale': array}}}}; each array is
  re-keyed by its dash-joined path, e.g. 'params-decoder-decoder_norm-scale'.

  Args:
    weights_dict: Linen checkpoint weights dictionary.

  Returns:
    Dictionary mapping parameter names to weight arrays.

  Raises:
    ValueError: if any leaf of the tree is not a jax or numpy array.
  """
  flattened = {}
  for key_path, value in jax.tree_util.tree_leaves_with_path(weights_dict):
    param_name = "params-" + "-".join(entry.key for entry in key_path)
    if not isinstance(value, (jax.Array, np.ndarray)):
      raise ValueError(f"Leaf value for {param_name} is not an array. Type: {type(value)}.")
    flattened[param_name] = value
  return flattened
852+
853+
854+
def detect_and_extract_checkpoint(checkpoint_dict: dict) -> dict[str, np.ndarray]:
  """Detect checkpoint type (Linen vs NNX) and extract weights.

  Handles multiple checkpoint variants:
  - Linen: {'params': {'params': {'decoder': {...}, 'token_embedder': ... {WEIGHT_ARRAY}}}}
  - NNX-SFT: {'decoder': {...}, 'token_embedder': ... {'value': WEIGHT_ARRAY}}
  - NNX-RL: {'base': {'decoder': {...}, 'token_embedder': ... {'value': WEIGHT_ARRAY}}}

  All extracted weights are aligned to the MaxText-Linen naming convention
  (e.g. "params-decoder-decoder_norm-scale") so the same param_mapping serves
  both Linen and NNX checkpoints.

  Args:
    checkpoint_dict: Raw checkpoint dictionary from Orbax

  Returns:
    Dictionary mapping MaxText parameter names to weight arrays
  """
  params_entry = checkpoint_dict.get("params")

  # A top-level 'params' key marks a Linen checkpoint.
  if params_entry is not None:
    if isinstance(params_entry, dict) and "params" in params_entry:
      # Doubly-nested 'params' layer: unwrap once before extraction.
      max_logging.log("Detected Linen checkpoint structure")
      return extract_linen_weights(params_entry["params"])
    max_logging.log("Detected Linen checkpoint structure (single params layer)")
    return extract_linen_weights(params_entry)

  # No 'params' key: NNX checkpoint with weights at the root.
  if "base" in checkpoint_dict and isinstance(checkpoint_dict["base"], dict):
    # NNX-RL wraps the model under a 'base' key.
    max_logging.log("Detected NNX-RL checkpoint structure (with 'base' wrapper)")
    return extract_nnx_weights(checkpoint_dict["base"])

  max_logging.log("Detected NNX-SFT checkpoint structure")
  return extract_nnx_weights(checkpoint_dict)
894+
895+
752896
def get_hf_model(model_id: str, token: str):
753897
"""Loads the HuggingFace model based on model_id (Eager mode only), used in to_maxtext"""
754898
if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]:

0 commit comments

Comments
 (0)