 from MaxText import max_logging
 import psutil

+from etils import epath
+import orbax.checkpoint as ocp
+

 SAFE_TENSORS_CONFIG_FILE = "config.json"
 SAFE_TENSORS_WEIGHTS_FILE = "model.safetensors"
@@ -130,6 +133,7 @@ def validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys):
       "maxtext_state_dict must be a subset of flattened param_map"
       + f"\nparam map\n{param_map_keys}"
       + f"\nmaxtext:\n{maxtext_state_keys}"
+      + f"\nmissing keys:\n{missing_keys}"
   )

   # 2 Filter: param map may have extra keys
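# A minimal sketch (not the PR's code) of the subset check whose error message
# the hunk above extends. It assumes `missing_keys` is simply the set
# difference between the MaxText state keys and the flattened param-map keys;
# the toy key names are illustrative.
param_map_keys = {"params-decoder-decoder_norm-scale", "params-token_embedder-embedding"}
maxtext_state_keys = {"params-decoder-decoder_norm-scale", "params-decoder-logits_dense-kernel"}
missing_keys = set(maxtext_state_keys) - set(param_map_keys)
if missing_keys:
  print(f"missing keys:\n{missing_keys}")  # -> {'params-decoder-logits_dense-kernel'}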
@@ -749,6 +753,146 @@ def print_ram_usage(stage=""):
   )


+def load_orbax_checkpoint(config) -> dict:
+  """Loads a full Orbax checkpoint from disk with unsharded arrays.
+
+  Args:
+    config: MaxText config containing checkpoint storage settings.
+
+  Returns:
+    Dictionary containing the full checkpoint structure.
+  """
+  # Create an Orbax checkpointer.
+  ckptr = ocp.Checkpointer(
+      ocp.PyTreeCheckpointHandler(
+          restore_concurrent_gb=config.checkpoint_storage_concurrent_gb,
+          use_ocdbt=config.checkpoint_storage_use_ocdbt,
+          use_zarr3=config.checkpoint_storage_use_zarr3,
+      )
+  )
+
+  # Get checkpoint metadata.
+  checkpoint_path = epath.Path(config.load_parameters_path)
+  metadata = ckptr.metadata(checkpoint_path)
+
+  # Build a 1D mesh over all devices; restoring with an empty PartitionSpec
+  # then fully replicates each array rather than sharding it.
+  devices = np.array(jax.devices()).reshape((-1,))
+  replicated_mesh = jax.sharding.Mesh(devices, ("x",))
+
+  def create_restore_args(tree_metadata):
+    """Create restore args for unsharded (fully replicated) restoration."""
+    if hasattr(tree_metadata, "shape"):
+      return ocp.ArrayRestoreArgs(
+          sharding=jax.sharding.NamedSharding(replicated_mesh, jax.sharding.PartitionSpec())
+      )
+    return None
+
+  restore_args = jax.tree_util.tree_map(
+      create_restore_args,
+      metadata.item_metadata.tree,
+      is_leaf=lambda x: hasattr(x, "shape"),
+  )
+
+  # Restore the entire checkpoint.
+  return ckptr.restore(checkpoint_path, restore_args=restore_args)
+
+
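# A self-contained sketch of the replicated-restore pattern above, on a toy
# pytree. The /tmp path, the toy tree, and the template-based restore_args are
# illustrative assumptions, and it presumes the same legacy
# PyTreeCheckpointHandler API this file already relies on.
import jax
import numpy as np
import orbax.checkpoint as ocp
from etils import epath

demo_path = epath.Path("/tmp/orbax_replicated_demo")  # hypothetical location
ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
ckptr.save(demo_path, {"w": np.arange(4.0)}, force=True)

# A 1D mesh over all devices plus an empty PartitionSpec yields full replication.
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape((-1,)), ("x",))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

# Instead of reading checkpoint metadata, reuse the known tree structure as a
# template for the per-leaf restore args.
restore_args = jax.tree_util.tree_map(
    lambda _: ocp.ArrayRestoreArgs(sharding=replicated), {"w": 0}
)
restored = ckptr.restore(demo_path, restore_args=restore_args)
print(restored["w"].sharding)  # every device holds the full array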
+def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
+  """Extract weights from NNX checkpoint structure.
+
+  NNX checkpoints have the structure {'decoder': {'decoder_norm': {'scale': {'value': array}}}};
+  this function flattens it to {'params-decoder-decoder_norm-scale': array}.
+
+  Args:
+    weights_dict: NNX checkpoint weights dictionary.
+
+  Returns:
+    Dictionary mapping parameter names to weight arrays.
+  """
+  result = {}
+  leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
+  for path_tuple, leaf_value in leaves_with_paths:
+    path_keys = [k.key for k in path_tuple]
+    # Skip NNX RNG state variables (not model weights).
+    if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
+      continue
+    # Drop a trailing "value" key: the parameter name is its parent path.
+    if path_keys[-1] == "value":
+      path_keys = path_keys[:-1]
+    maxtext_param_key = "params-" + "-".join(path_keys)
+    if not isinstance(leaf_value, (jax.Array, np.ndarray)):
+      raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
+    result[maxtext_param_key] = leaf_value
+  return result
+
+
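# Illustrative input for extract_nnx_weights above (assuming it is importable
# from this module): 'value' wrappers are dropped and NNX rng state is
# skipped. The toy arrays and key names are made up.
import numpy as np

nnx_ckpt = {
    "decoder": {"decoder_norm": {"scale": {"value": np.ones(8)}}},
    "to_nnx__rngs": {"default": {"count": {"value": np.zeros(())}}},
}
print(extract_nnx_weights(nnx_ckpt).keys())
# -> dict_keys(['params-decoder-decoder_norm-scale'])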
+def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]:
+  """Extract weights from Linen checkpoint structure.
+
+  Linen checkpoints have the structure {'params': {'decoder': {'decoder_norm': {'scale': array}}}};
+  this function flattens it to {'params-decoder-decoder_norm-scale': array}.
+
+  Args:
+    weights_dict: Linen checkpoint weights dictionary.
+
+  Returns:
+    Dictionary mapping parameter names to weight arrays.
+  """
+  result = {}
+  leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
+  for path_tuple, leaf_value in leaves_with_paths:
+    path_keys = [k.key for k in path_tuple]
+    # Construct the MaxText parameter key from the tree path.
+    maxtext_param_key = "params-" + "-".join(path_keys)
+    if not isinstance(leaf_value, (jax.Array, np.ndarray)):
+      raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
+    result[maxtext_param_key] = leaf_value
+  return result
+
+
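# Counterpart toy input for extract_linen_weights above: no 'value' wrapper,
# so each dict path maps directly onto a parameter name. The array contents
# are illustrative.
import numpy as np

linen_params = {"decoder": {"decoder_norm": {"scale": np.ones(8)}}}
print(extract_linen_weights(linen_params).keys())
# -> dict_keys(['params-decoder-decoder_norm-scale'])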
+def detect_and_extract_checkpoint(checkpoint_dict: dict) -> dict[str, np.ndarray]:
+  """Detect checkpoint type (Linen vs NNX) and extract weights.
+
+  Handles the following checkpoint layouts:
+  - Linen: {'params': {'params': {'decoder': {...}, 'token_embedder': ... WEIGHT_ARRAY}}}
+  - NNX-SFT: {'decoder': {...}, 'token_embedder': ... {'value': WEIGHT_ARRAY}}
+  - NNX-RL: {'base': {'decoder': {...}, 'token_embedder': ... {'value': WEIGHT_ARRAY}}}
+
+  All extracted weights are aligned to the MaxText-Linen naming convention,
+  e.g. "params-decoder-decoder_norm-scale", which allows reusing the same
+  param_mapping for both Linen and NNX checkpoints.
+
+  Args:
+    checkpoint_dict: Raw checkpoint dictionary from Orbax.
+
+  Returns:
+    Dictionary mapping MaxText parameter names to weight arrays.
+  """
+  # Detect the checkpoint type by its structure.
+  actual_weights_dict = checkpoint_dict.get("params")
+
+  if actual_weights_dict is None:
+    # NNX checkpoint: the structure starts directly at the root.
+    # Check for the NNX-RL variant with a 'base' wrapper.
+    if "base" in checkpoint_dict and isinstance(checkpoint_dict["base"], dict):
+      # NNX-RL: {'base': {'decoder': ..., 'token_embedder': ...}}
+      max_logging.log("Detected NNX-RL checkpoint structure (with 'base' wrapper)")
+      return extract_nnx_weights(checkpoint_dict["base"])
+    else:
+      # NNX-SFT: {'decoder': ..., 'token_embedder': ...}
+      max_logging.log("Detected NNX-SFT checkpoint structure")
+      return extract_nnx_weights(checkpoint_dict)
+  else:
+    # Linen checkpoint: unwrap a nested 'params' key if present.
+    if isinstance(actual_weights_dict, dict) and "params" in actual_weights_dict:
+      actual_weights_dict = actual_weights_dict["params"]
+      max_logging.log("Detected Linen checkpoint structure")
+    else:
+      max_logging.log("Detected Linen checkpoint structure (single params layer)")
+    return extract_linen_weights(actual_weights_dict)
+
+
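# A quick illustrative check that all three layouts from the docstring above
# resolve to the same MaxText-Linen key; the toy trees mirror the docstring,
# not any real checkpoint.
import numpy as np

scale = np.ones(8)
linen = {"params": {"params": {"decoder": {"decoder_norm": {"scale": scale}}}}}
nnx_sft = {"decoder": {"decoder_norm": {"scale": {"value": scale}}}}
nnx_rl = {"base": nnx_sft}
for ckpt in (linen, nnx_sft, nnx_rl):
  assert "params-decoder-decoder_norm-scale" in detect_and_extract_checkpoint(ckpt)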
 def get_hf_model(model_id: str, token: str):
   """Loads the HuggingFace model based on model_id (Eager mode only), used in to_maxtext"""
   if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]: