Skip to content

Latest commit

 

History

History
240 lines (177 loc) · 14.9 KB

File metadata and controls

240 lines (177 loc) · 14.9 KB

Checkpoint conversion utilities

This guide provides instructions for using the scripts that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.

Supported models

The following models are supported:

Model Family Sizes HF $\to$ Orbax (scan) HF $\to$ Orbax (unscan) Orbax (scan) $\to$ HF Orbax (unscan) $\to$ HF
Gemma2 2B, 9B, 27B
Gemma3 (Multimodal) 4B, 12B, 27B
Llama3.1 8B, 70B, 450B
Qwen3 0.6B, 4B, 8B, 14B, 32B
Qwen3 MoE 30B, 235B, 480B
Mixtral 8x7B, 8x22B
GPT-OSS 20B, 120B
DeepSeek3 671B - - -
Qwen3 Next 80B

Prerequisites

  • Hugging Face requires Pytorch.
  • Hugging Face model checkpoints require local disk space.
    • The model files are always downloaded to a disk cache first before being loaded into memory (for more info, please consult Hugging Face docs). The default local storage path for Hugging Face models is $HOME/.cache/huggingface/hub

Hugging Face to MaxText

Use the to_maxtext.py script to convert a Hugging Face model into a MaxText checkpoint. The script will automatically download the specified model from the Hugging Face Hub, perform conversion, and save converted checkpoints to given output directory.

**For a complete example, see the test script at tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh and tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh.

Usage

First, make sure python3 virtual environment for MaxText is set up and enabled.

export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed ${VENV_NAME?}
source ${VENV_NAME?}/bin/activate

Second, ensure you have the necessary dependencies installed (e.g., install PyTorch for checkpoint conversion and logit check).

python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

Third, setup following environment variables for conversion script

# -- Model configuration --
export HF_MODEL=<Hugging Face Model to be converted to MaxText> # e.g. 'llama3.1-8b-Instruct'
export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repos

# -- MaxText configuration --
export MODEL_CHECKPOINT_DIRECTORY=<output directory to store output of checking point> # e.g., gs://my-bucket/my-checkpoint-directory

# -- storage and format options
export USE_ZARR3=<Flag to use zarr3> # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
export USE_OCDBT=<Flag to use ocdbt> # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.

export LAZY_LOAD_TENSORS=<Flag to lazy load> # True to use lazy load, False to use eager load.

Finally, run below command to complete the conversion

python3 -m maxtext.checkpoint_conversion.to_maxtext maxtext/configs/base.yml \
    model_name=${HF_MODEL?} \
    hf_access_token=${HF_TOKEN?} \
    base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \
    scan_layers=True \
    use_multimodal=false \
    hardware=cpu \
    skip_jax_distributed_system=true \
    checkpoint_storage_use_zarr3=${USE_ZARR3?} \
    checkpoint_storage_use_ocdbt=${USE_OCDBT?} \
    --lazy_load_tensors=${LAZY_LOAD_TENSORS?}

Key arguments:

  • model_name: The model identifier, which should be defined in src/maxtext/configs/types.py.
  • scan_layers: Indicates if the output checkpoint is scanned (scan_layers=true) or unscanned (scan_layers=false).
  • use_multimodal: Indicates if multimodality is used, important for Gemma3.
  • hf_access_token: Your Hugging Face token.
  • base_output_directory: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS) or local. If not set, the default output directory is Maxtext/tmp.
  • hardware=cpu: run the conversion script on a CPU machine.
  • checkpoint_storage_use_zarr3: Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
  • checkpoint_storage_use_ocdbt: Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
  • --lazy_load_tensors (optional): If true, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the --lazy_load_tensors=true flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with --lazy_load_tensors=true uses around 200GB of RAM and completes in ~10 minutes.
  • --hf_model_path (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the default Hugging Face repository ID (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

Above command will download the Hugging Face model to local machine if hf_model_path is unspecified, or reuse the checkpoint in hf_model_path. It will convert the checkpoint to the MaxText format and save it to ${MODEL_CHECKPOINT_DIRECTORY}/0/items.

MaxText to Hugging Face

Use the to_huggingface.py script to convert a MaxText checkpoint into the Hugging Face format. This is useful for sharing your models or integrating them with the Hugging Face ecosystem. **For a complete example, see the test script at tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh.

Usage

The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub.

python3 -m maxtext.checkpoint_conversion.to_huggingface src/maxtext/configs/base.yml \
    model_name=<MODEL_NAME> \
    load_parameters_path=<path-to-maxtext-checkpoint> \
    base_output_directory=<path-to-save-converted-checkpoint> \
    scan_layers=false \
    use_multimodal=false \
    hf_access_token=<your-hf-token> \
    weight_dtype=bfloat16

Key arguments:

  • load_parameters_path: The path to the source MaxText Orbax checkpoint (e.g., gs://your-bucket/maxtext-checkpoint/0/items).
  • model_name: The corresponding model name in the MaxText configuration (e.g., qwen3-4b).
  • scan_layers: Indicates if the output checkpoint is scanned (scan_layers=true) or unscanned (scan_layers=false).
  • hf_access_token: Your Hugging Face token.
  • use_multimodal: Indicates if multimodality is used, important for Gemma3.
  • base_output_directory: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS), Hugging Face Hub or local. If not set, the default output directory is Maxtext/tmp.
  • weight_dtype: dtype for MaxText weights. It affects the resulting HF weight dtype. Default value is float32. We recommend using bfloat16 to save memory and speed up conversion.

Verifying conversion correctness

To ensure the conversion was successful, you can use the tests/utils/forward_pass_logit_checker.py script. It runs a forward pass on both the original and converted models and compares the output logits to verify conversion. It is used to verify the bidirectional conversion.

Usage

python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \
    tokenizer_path=<tokenizer> \
    load_parameters_path=<path-to-maxtext-checkpoint> \
    model_name=<MODEL_NAME> \
    scan_layers=false \
    max_prefill_predict_length=4 \
    max_target_length=8 \
    use_multimodal=false \
    --run_hf_model=True \
    --hf_model_path=<path-to-HF-checkpoint> \
    --max_kl_div=0.015

Key arguments:

  • load_parameters_path: The path to the source MaxText Orbax checkpoint (e.g., gs://your-bucket/maxtext-checkpoint/0/items).
  • model_name: The corresponding model name in the MaxText configuration (e.g., qwen3-4b).
  • scan_layers: Indicates if the output checkpoint is scanned (scan_layers=true) or unscanned (scan_layers=false).
  • use_multimodal: Indicates if multimodality is used.
  • --run_hf_model (optional): Indicates if loading Hugging Face model from the hf_model_path. If not set, it will compare the maxtext logits with pre-saved golden logits.
  • --hf_model_path (optional): The path to the Hugging Face checkpoint (if --run_hf_model=True)
  • --golden_logits_path (optional): The pre-saved golden logits. (if --run_hf_model is not set)
  • --max_kl_div: Max KL divergence tolerance during comparisons.

Example successful conversion verification:

Here is part of the output of forward_pass_logit_checker for the gemma2-2b.

--- Prompt: What is the ---

--- MaxText model top 10 tokens ---
| Token ID   | Token                | Score      |
|------------|----------------------|------------|
| 5830       | difference           | 27.2500    |
| 1963       | best                 | 26.6250    |
| 5316       | average              | 26.3750    |
| 2669       | change               | 26.1250    |
| 12070      | percentage           | 26.1250    |
| 1618       | value                | 25.8750    |
| 1546       | most                 | 25.7500    |
| 66202      | molar                | 25.5000    |
| 3051       | total                | 25.5000    |
| 1503       | name                 | 25.3750    |


--- HF model top 10 tokens ---
| Token ID   | Token                | Score      |
|------------|----------------------|------------|
| 5830       | difference           | 27.2500    |
| 1963       | best                 | 26.6250    |
| 5316       | average              | 26.3750    |
| 12070      | percentage           | 26.1250    |
| 2669       | change               | 26.1250    |
| 1618       | value                | 25.8750    |
| 1546       | most                 | 25.7500    |
| 66202      | molar                | 25.5000    |
| 3051       | total                | 25.5000    |
| 6187       | purpose              | 25.3750    |


--- Similarity Metrics of Top Tokens ---
| Metric                         | Value                |
|--------------------------------|----------------------|
| overlap_count                  | 9/10                 |
| jaccard_similarity             | 0.8181818181818182   |
| rank_agreement_percentage      | 70.0                 |


Average KL divergence per token (D_KL(P_golden || Q_model)): 0.000409

Max KL divergence for a single token in the set: 0.003497

Adding support for new models

To extend conversion support to a new model architecture, you must define its specific parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files.

  1. Add parameter mappings:
  • In utils/param_mapping.py, add the parameter name mappings(def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING). This is the 1-to-1 mappings of parameters names per layer.
  • In utils/param_mapping.py, add the hook_fn logic (def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN). This is the transformation needed per layer.
  1. Add Hugging Face weights Shape: In utils/hf_shape.py, define the tensor shape of Hugging Face format (def {MODEL}_HF_WEIGHTS_TO_SHAPE). This is used to ensure the tensor shape is matched after to_huggingface conversion.
  2. Register model key: In utils/utils.py, add the new model key in HF_IDS.
  3. Add transformer config: In utils/hf_model_configs.py, add the transformers.Config object, describing the Hugging Face model configuration (defined in src/maxtext/configs/models). Note: This configuration must precisely match the MaxText model's architecture.

Here is an example PR to add support for gemma3 multi-modal model

Debugging tips

If the converted checkpoint can not get loaded and got error like: "type <class 'jax._src.core.ShapeDtypeStruct'> is not a valid JAX type."

  • Potential Cause: The scan_layers flag is set wrong.

If a converted checkpoint loads without errors but produces incorrect output, consider these common issues:

  • Symptom: The model generates garbage or nonsensical tokens.

    • Potential Cause: The query/key/value (Q/K/V) or Out vectors weights were likely reshaped incorrectly during conversion.
  • Symptom: The model generates repetitive text sequences.

    • Potential Cause: The layer normalization parameters may have been converted incorrectly.