This guide provides instructions for using the scripts that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.
The following models are supported:

| Model Family | Sizes | HF → Orbax (scan) | HF → Orbax (unscan) | Orbax (scan) → HF | Orbax (unscan) → HF |
|---|---|---|---|---|---|
| Gemma2 | 2B, 9B, 27B | √ | √ | √ | √ |
| Gemma3 (Multimodal) | 4B, 12B, 27B | √ | √ | √ | √ |
| Llama3.1 | 8B, 70B, 405B | √ | √ | √ | √ |
| Qwen3 | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ |
| Qwen3 MoE | 30B, 235B, 480B | √ | √ | √ | √ |
| Mixtral | 8x7B, 8x22B | √ | √ | √ | √ |
| GPT-OSS | 20B, 120B | √ | √ | √ | √ |
| DeepSeek3 | 671B | - | - | √ | - |
| Qwen3 Next | 80B | √ | √ | √ | √ |
- Hugging Face checkpoint conversion requires PyTorch.
- Hugging Face model checkpoints require local disk space.
  - The model files are always downloaded to a disk cache first before being loaded into memory (for more info, please consult the Hugging Face docs). The default local storage path for Hugging Face models is `$HOME/.cache/huggingface/hub`.
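To check where downloads will land (and how much disk they need) before a large conversion, the sketch below resolves the cache directory from environment variables. It assumes the precedence used by `huggingface_hub` (`HF_HUB_CACHE`, then `$HF_HOME/hub`, then `$XDG_CACHE_HOME/huggingface/hub`, then the default above); `resolve_hf_hub_cache` is a hypothetical helper, not part of MaxText or `huggingface_hub`.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_hf_hub_cache(env: Optional[Mapping[str, str]] = None) -> Path:
    """Best-effort guess at the Hugging Face Hub cache directory.

    Assumed precedence (mirrors huggingface_hub's defaults):
    HF_HUB_CACHE > $HF_HOME/hub > $XDG_CACHE_HOME/huggingface/hub
    > $HOME/.cache/huggingface/hub.
    """
    env = os.environ if env is None else env
    # An explicit hub cache location wins outright.
    if env.get("HF_HUB_CACHE"):
        return Path(env["HF_HUB_CACHE"]).expanduser()
    # Otherwise the hub cache is the "hub" subdirectory of HF_HOME.
    if env.get("HF_HOME"):
        hf_home = Path(env["HF_HOME"]).expanduser()
    else:
        home = env.get("HOME") or str(Path.home())
        xdg_cache = env.get("XDG_CACHE_HOME") or os.path.join(home, ".cache")
        hf_home = Path(xdg_cache).expanduser() / "huggingface"
    return hf_home / "hub"


if __name__ == "__main__":
    print(resolve_hf_hub_cache())
```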
Use the `to_maxtext.py` script to convert a Hugging Face model into a MaxText checkpoint. The script automatically downloads the specified model from the Hugging Face Hub, performs the conversion, and saves the converted checkpoint to the given output directory.
**For a complete example**, see the test scripts at `tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh` and `tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh`.
First, make sure a Python 3 virtual environment for MaxText is set up and activated:
```bash
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed ${VENV_NAME?}
source ${VENV_NAME?}/bin/activate
```

Second, ensure you have the necessary dependencies installed (e.g., install PyTorch for checkpoint conversion and logit checking):
```bash
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
```

Third, set the following environment variables for the conversion script:
```bash
# -- Model configuration --
export HF_MODEL=<Hugging Face Model to be converted to MaxText> # e.g. 'llama3.1-8b-Instruct'
export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repos
# -- MaxText configuration --
export MODEL_CHECKPOINT_DIRECTORY=<output directory for the converted checkpoint> # e.g., gs://my-bucket/my-checkpoint-directory
# -- Storage and format options --
export USE_ZARR3=<Flag to use zarr3> # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
export USE_OCDBT=<Flag to use ocdbt> # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
export LAZY_LOAD_TENSORS=<Flag to lazy load> # Set to True to load tensors lazily, False to load them eagerly.
```

Finally, run the following command to perform the conversion:
```bash
python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
model_name=${HF_MODEL?} \
hf_access_token=${HF_TOKEN?} \
base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \
scan_layers=True \
use_multimodal=false \
hardware=cpu \
skip_jax_distributed_system=true \
checkpoint_storage_use_zarr3=${USE_ZARR3?} \
checkpoint_storage_use_ocdbt=${USE_OCDBT?} \
--lazy_load_tensors=${LAZY_LOAD_TENSORS?}
```

Key arguments:
- `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`.
- `scan_layers`: Indicates whether the output checkpoint is scanned (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
- `hf_access_token`: Your Hugging Face token.
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be a Google Cloud Storage (GCS) path or a local path. If not set, the default output directory is `Maxtext/tmp`.
- `hardware=cpu`: Runs the conversion script on a CPU machine.
- `checkpoint_storage_use_zarr3`: Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
- `checkpoint_storage_use_ocdbt`: Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on demand to minimize RAM usage. When memory is constrained, use `--lazy_load_tensors=true` to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in about 10 minutes.
- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, the default Hugging Face repository ID (e.g., `openai/gpt-oss-20b`) is used. This is necessary for locally dequantized models such as GPT-OSS or DeepSeek.
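Because a mistyped flag value (e.g., `Ture`) or a missing variable only surfaces partway through a long conversion, it can help to validate the environment and assemble the command up front. The sketch below does that; `build_to_maxtext_argv` and `parse_bool` are hypothetical helpers, not part of MaxText. It normalizes boolean flags to lowercase on the assumption that MaxText's config parser and the `--lazy_load_tensors` flag accept lowercase values, as the command above already does for `use_multimodal=false`.

```python
from typing import List, Mapping

# Environment variables exported in the steps above (names taken from this guide).
REQUIRED_VARS = ("HF_MODEL", "HF_TOKEN", "MODEL_CHECKPOINT_DIRECTORY",
                 "USE_ZARR3", "USE_OCDBT", "LAZY_LOAD_TENSORS")
BOOL_VARS = ("USE_ZARR3", "USE_OCDBT", "LAZY_LOAD_TENSORS")


def parse_bool(name: str, value: str) -> str:
    """Normalize a True/False-style value to 'true'/'false', rejecting typos early."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return "true"
    if lowered in ("false", "0", "no"):
        return "false"
    raise ValueError(f"{name} must be True or False, got {value!r}")


def build_to_maxtext_argv(env: Mapping[str, str],
                          config: str = "src/maxtext/configs/base.yml",
                          scan_layers: bool = True) -> List[str]:
    """Assemble the to_maxtext command (hypothetical helper) from environment variables."""
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    flags = {name: parse_bool(name, env[name]) for name in BOOL_VARS}
    return [
        "python3", "-m", "maxtext.checkpoint_conversion.to_maxtext", config,
        f"model_name={env['HF_MODEL']}",
        f"hf_access_token={env['HF_TOKEN']}",
        f"base_output_directory={env['MODEL_CHECKPOINT_DIRECTORY']}",
        f"scan_layers={'true' if scan_layers else 'false'}",
        "use_multimodal=false",
        "hardware=cpu",
        "skip_jax_distributed_system=true",
        f"checkpoint_storage_use_zarr3={flags['USE_ZARR3']}",
        f"checkpoint_storage_use_ocdbt={flags['USE_OCDBT']}",
        f"--lazy_load_tensors={flags['LAZY_LOAD_TENSORS']}",
    ]
```

Passing the resulting list to `subprocess.run(argv, check=True)` avoids shell-quoting issues. Avoid logging the list as-is, since it contains `HF_TOKEN`.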
The `to_maxtext` command downloads the Hugging Face model to the local machine if `hf_model_path` is unspecified, or reuses the checkpoint in `hf_model_path`. It then converts the checkpoint to the MaxText format and saves it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.
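The `.../0/items` path is what later steps take as `load_parameters_path`. A minimal sketch for deriving it, treating `0` as the step directory (an assumption for layouts other than the one described here); `maxtext_items_path` is a hypothetical helper:

```python
def maxtext_items_path(base_output_directory: str, step: int = 0) -> str:
    """Return <base>/<step>/items, the checkpoint layout described in this guide.

    Plain string joining is used on purpose: pathlib would collapse the
    double slash in "gs://" URIs into "gs:/".
    """
    if step < 0:
        raise ValueError("step must be non-negative")
    return f"{base_output_directory.rstrip('/')}/{step}/items"
```

For example, `maxtext_items_path("gs://my-bucket/my-checkpoint-directory")` yields `gs://my-bucket/my-checkpoint-directory/0/items`.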
Use the to_huggingface.py script to convert a MaxText checkpoint into the Hugging Face format. This is useful for sharing your models or integrating them with the Hugging Face ecosystem.
**For a complete example**, see the test script at `tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh`.
The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub.
```bash
python3 -m maxtext.checkpoint_conversion.to_huggingface src/maxtext/configs/base.yml \
model_name=<MODEL_NAME> \
load_parameters_path=<path-to-maxtext-checkpoint> \
base_output_directory=<path-to-save-converted-checkpoint> \
scan_layers=false \
use_multimodal=false \
hf_access_token=<your-hf-token> \
weight_dtype=bfloat16
```

Key arguments:
- `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
- `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
- `scan_layers`: Indicates whether the source MaxText checkpoint is scanned (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `hf_access_token`: Your Hugging Face token.
- `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
- `base_output_directory`: The path where the converted Hugging Face checkpoint will be stored; it can be Google Cloud Storage (GCS), the Hugging Face Hub, or a local path. If not set, the default output directory is `Maxtext/tmp`.
- `weight_dtype`: The dtype of the MaxText weights, which determines the dtype of the resulting HF weights. The default is `float32`; we recommend `bfloat16` to save memory and speed up conversion.
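The memory saving from `weight_dtype=bfloat16` is simple arithmetic: each parameter takes 2 bytes instead of 4. The sketch below estimates weight-only size in decimal GB for a nominal parameter count; it ignores metadata and any temporary copies made during conversion, so real peak memory will be higher. `checkpoint_size_gb` is a hypothetical helper.

```python
# Bytes per element for the dtypes mentioned above.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}


def checkpoint_size_gb(num_params: float, dtype: str) -> float:
    """Rough weight-only size in decimal GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1e9


if __name__ == "__main__":
    # A nominal 4B-parameter model: about 16 GB in float32, 8 GB in bfloat16.
    for dtype in ("float32", "bfloat16"):
        print(f"4B params as {dtype}: ~{checkpoint_size_gb(4e9, dtype):.1f} GB")
```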
To ensure the conversion was successful, you can use the `tests/utils/forward_pass_logit_checker.py` script. It runs a forward pass on both the original and converted models and compares their output logits, so it can verify conversion in either direction.
```bash
python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \
tokenizer_path=<tokenizer> \
load_parameters_path=<path-to-maxtext-checkpoint> \
model_name=<MODEL_NAME> \
scan_layers=false \
max_prefill_predict_length=4 \
max_target_length=8 \
use_multimodal=false \
--run_hf_model=True \
--hf_model_path=<path-to-HF-checkpoint> \
--max_kl_div=0.015
```

Key arguments:
- `load_parameters_path`: The path to the MaxText Orbax checkpoint to check (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
- `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
- `scan_layers`: Indicates whether the MaxText checkpoint is scanned (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `use_multimodal`: Indicates whether multimodality is used.
- `--run_hf_model` (optional): If set, loads the Hugging Face model from `--hf_model_path` and compares against its logits. If not set, the MaxText logits are compared with pre-saved golden logits.
- `--hf_model_path` (optional): The path to the Hugging Face checkpoint (used if `--run_hf_model=True`).
- `--golden_logits_path` (optional): The path to the pre-saved golden logits (used if `--run_hf_model` is not set).
- `--max_kl_div`: The maximum KL divergence tolerated during comparison.
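The KL check follows the standard definition: for each token position, turn both logit vectors into probability distributions with softmax and compute D_KL(P_golden || Q_model) = Σ p·(log p − log q). The sketch below computes the per-token average and maximum in pure Python; it is an illustration of the metric, not the checker's code, and it assumes the per-token maximum is what gets compared against `max_kl_div` (the checker may gate differently).

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax using the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def kl_divergence(p_logits: Sequence[float], q_logits: Sequence[float]) -> float:
    """D_KL(P || Q) where P = softmax(p_logits) and Q = softmax(q_logits)."""
    if len(p_logits) != len(q_logits):
        raise ValueError("vocabulary sizes differ")
    log_p = log_softmax(p_logits)
    log_q = log_softmax(q_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def compare_logits(golden: Sequence[Sequence[float]],
                   model: Sequence[Sequence[float]],
                   max_kl_div: float) -> Tuple[float, float, bool]:
    """Return (average KL per token, max KL for a single token, within tolerance?)."""
    if len(golden) != len(model) or not golden:
        raise ValueError("need the same, non-zero number of token positions")
    per_token = [kl_divergence(p, q) for p, q in zip(golden, model)]
    worst = max(per_token)
    return sum(per_token) / len(per_token), worst, worst <= max_kl_div
```

Because softmax ignores a constant shift, logits that differ only by an offset give a KL divergence of (numerically) zero.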
Example of a successful conversion verification:
Here is part of the output of `forward_pass_logit_checker` for gemma2-2b:
```
--- Prompt: What is the ---
--- MaxText model top 10 tokens ---
| Token ID | Token | Score |
|------------|----------------------|------------|
| 5830 | difference | 27.2500 |
| 1963 | best | 26.6250 |
| 5316 | average | 26.3750 |
| 2669 | change | 26.1250 |
| 12070 | percentage | 26.1250 |
| 1618 | value | 25.8750 |
| 1546 | most | 25.7500 |
| 66202 | molar | 25.5000 |
| 3051 | total | 25.5000 |
| 1503 | name | 25.3750 |
--- HF model top 10 tokens ---
| Token ID | Token | Score |
|------------|----------------------|------------|
| 5830 | difference | 27.2500 |
| 1963 | best | 26.6250 |
| 5316 | average | 26.3750 |
| 12070 | percentage | 26.1250 |
| 2669 | change | 26.1250 |
| 1618 | value | 25.8750 |
| 1546 | most | 25.7500 |
| 66202 | molar | 25.5000 |
| 3051 | total | 25.5000 |
| 6187 | purpose | 25.3750 |
--- Similarity Metrics of Top Tokens ---
| Metric | Value |
|--------------------------------|----------------------|
| overlap_count | 9/10 |
| jaccard_similarity | 0.8181818181818182 |
| rank_agreement_percentage | 70.0 |
Average KL divergence per token (D_KL(P_golden || Q_model)): 0.000409
Max KL divergence for a single token in the set: 0.003497
```
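The similarity metrics in this output can be reproduced from the two top-10 lists. The definitions in this sketch are inferred from the numbers shown (rank agreement is taken to mean the same token at the same position) and may not match the checker's implementation exactly; `top_k_similarity` is a hypothetical helper.

```python
from typing import Sequence

# Token IDs copied from the gemma2-2b output above, in rank order.
MAXTEXT_TOP10 = [5830, 1963, 5316, 2669, 12070, 1618, 1546, 66202, 3051, 1503]
HF_TOP10 = [5830, 1963, 5316, 12070, 2669, 1618, 1546, 66202, 3051, 6187]


def top_k_similarity(a: Sequence[int], b: Sequence[int]) -> dict:
    """Overlap count, Jaccard similarity, and positional rank agreement of two top-k lists."""
    if len(a) != len(b) or not a:
        raise ValueError("top-k lists must be non-empty and the same length")
    set_a, set_b = set(a), set(b)
    overlap = len(set_a & set_b)
    same_rank = sum(1 for x, y in zip(a, b) if x == y)
    return {
        "overlap_count": overlap,
        "jaccard_similarity": overlap / len(set_a | set_b),
        "rank_agreement_percentage": 100.0 * same_rank / len(a),
    }


if __name__ == "__main__":
    for name, value in top_k_similarity(MAXTEXT_TOP10, HF_TOP10).items():
        print(f"{name}: {value}")
```

Note that the tokens swapped at ranks 4 and 5 (`change` and `percentage`) have identical scores (26.1250) in both models, so a tie-breaking difference alone lowers rank agreement; this is why a modest rank-agreement percentage can coexist with a very small KL divergence.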
To extend conversion support to a new model architecture, you must define its specific parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files.
- Add parameter mappings:
  - In `utils/param_mapping.py`, add the parameter name mappings (`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the one-to-one mapping of parameter names for each layer.
  - In `utils/param_mapping.py`, add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed for each layer.
- Add Hugging Face weight shapes: In `utils/hf_shape.py`, define the tensor shapes of the Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). These are used to ensure that tensor shapes match after `to_huggingface` conversion.
- Register the model key: In `utils/utils.py`, add the new model key in `HF_IDS`.
- Add the transformers config: In `utils/hf_model_configs.py`, add the `transformers.Config` object describing the Hugging Face model configuration. Note: this configuration must precisely match the MaxText model's architecture (defined in `src/maxtext/configs/models`).
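To illustrate how the name mapping, the hook function, and the shape table fit together, here is a toy sketch for a single MLP weight per layer. Every parameter name and function in it is hypothetical and only mirrors the roles of `{MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`, `{MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`, and `{MODEL}_HF_WEIGHTS_TO_SHAPE`; the real keys and signatures in MaxText differ. The transpose assumes a Flax-style `[in, out]` kernel on the MaxText side and a PyTorch `nn.Linear`-style `[out, in]` weight on the Hugging Face side.

```python
from typing import Callable, Dict, List, Tuple

Matrix = List[List[float]]

# HYPOTHETICAL names throughout: they illustrate the roles of the three pieces only.


def toy_maxtext_to_hf_param_mapping(num_layers: int) -> Dict[str, str]:
    """One-to-one MaxText -> HF parameter-name mapping, one entry per layer."""
    return {f"decoder.layers_{i}.mlp.wi.kernel": f"model.layers.{i}.mlp.up_proj.weight"
            for i in range(num_layers)}


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def toy_maxtext_to_hf_param_hook_fn(num_layers: int) -> Dict[str, Callable[[Matrix], Matrix]]:
    """Per-parameter transformation: [in, out] kernel -> [out, in] weight."""
    return {f"decoder.layers_{i}.mlp.wi.kernel": transpose for i in range(num_layers)}


def toy_hf_weights_to_shape(num_layers: int, hidden: int,
                            intermediate: int) -> Dict[str, Tuple[int, int]]:
    """Expected HF tensor shapes, used to validate the converted weights."""
    return {f"model.layers.{i}.mlp.up_proj.weight": (intermediate, hidden)
            for i in range(num_layers)}


def shape_of(m: Matrix) -> Tuple[int, int]:
    return (len(m), len(m[0]) if m else 0)


def convert(params: Dict[str, Matrix], num_layers: int,
            hidden: int, intermediate: int) -> Dict[str, Matrix]:
    """Rename each parameter, apply its hook, and check the resulting shape."""
    mapping = toy_maxtext_to_hf_param_mapping(num_layers)
    hooks = toy_maxtext_to_hf_param_hook_fn(num_layers)
    shapes = toy_hf_weights_to_shape(num_layers, hidden, intermediate)
    hf_params = {}
    for mt_name, value in params.items():
        hf_name = mapping[mt_name]
        hf_value = hooks.get(mt_name, lambda x: x)(value)
        if shape_of(hf_value) != shapes[hf_name]:
            raise ValueError(f"{hf_name}: got {shape_of(hf_value)}, expected {shapes[hf_name]}")
        hf_params[hf_name] = hf_value
    return hf_params
```

Keeping renaming, transformation, and shape validation in separate tables is what lets the conversion driver stay model-agnostic, as described above.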
Here is an example PR that adds support for the Gemma3 multimodal model.
If the converted checkpoint cannot be loaded and you get an error like `type <class 'jax._src.core.ShapeDtypeStruct'> is not a valid JAX type.`:
- Potential Cause: The `scan_layers` flag is set incorrectly (it does not match whether the checkpoint is scanned).
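The two layouts have different parameter trees: an unscanned checkpoint stores one entry per layer, while a scanned checkpoint stacks each parameter across layers into a single entry. A loader configured for one layout therefore looks for leaves that the other layout does not contain. The toy sketch below shows the structural difference; the key names and the stacking position are assumptions, not MaxText's real layout, and `to_scanned`/`to_unscanned` are hypothetical helpers.

```python
from typing import Dict, List

# Toy trees; key names and stacking position are assumptions, not MaxText's layout.
Unscanned = Dict[str, Dict[str, object]]
Scanned = Dict[str, Dict[str, List[object]]]


def to_scanned(unscanned: Unscanned) -> Scanned:
    """Stack per-layer entries 'layers_{i}' into one 'layers' entry indexed by layer."""
    num_layers = len(unscanned)
    names = list(unscanned["layers_0"])
    return {"layers": {name: [unscanned[f"layers_{i}"][name] for i in range(num_layers)]
                       for name in names}}


def to_unscanned(scanned: Scanned) -> Unscanned:
    """Split the leading layer index back into separate 'layers_{i}' entries."""
    stacked = scanned["layers"]
    num_layers = len(next(iter(stacked.values())))
    return {f"layers_{i}": {name: values[i] for name, values in stacked.items()}
            for i in range(num_layers)}
```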
If a converted checkpoint loads without errors but produces incorrect output, consider these common issues:
- Symptom: The model generates garbage or nonsensical tokens.
  - Potential Cause: The query/key/value (Q/K/V) or output projection weights were likely reshaped incorrectly during conversion.
- Symptom: The model generates repetitive text sequences.
  - Potential Cause: The layer normalization parameters may have been converted incorrectly.
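An incorrect Q/K/V reshape is easy to miss because the result has the right shape and loads cleanly; only the values are scrambled. The sketch below contrasts a correct split of a query weight into heads with a naive reshape. It assumes the Hugging Face `q_proj.weight` uses PyTorch's `[num_heads * head_dim, hidden]` (out, in) layout and that the target kernel is `[hidden, num_heads, head_dim]`; both functions are hypothetical, and any other per-architecture transforms are ignored.

```python
from typing import List

Matrix = List[List[float]]
Tensor3 = List[List[List[float]]]


def hf_query_to_heads(w: Matrix, num_heads: int, head_dim: int) -> Tensor3:
    """Correct: transpose [heads*head_dim, hidden], then split heads.

    out[e][h][d] == w[h * head_dim + d][e]
    """
    if len(w) != num_heads * head_dim:
        raise ValueError("row count must equal num_heads * head_dim")
    hidden = len(w[0])
    return [[[w[h * head_dim + d][e] for d in range(head_dim)]
             for h in range(num_heads)] for e in range(hidden)]


def naive_reshape(w: Matrix, num_heads: int, head_dim: int) -> Tensor3:
    """Buggy: reinterprets the row-major buffer as [hidden, heads, head_dim] without transposing."""
    flat = [x for row in w for x in row]
    hidden = len(w[0])
    return [[[flat[(e * num_heads + h) * head_dim + d] for d in range(head_dim)]
             for h in range(num_heads)] for e in range(hidden)]
```

With 2 heads of size 2 and a hidden size of 3, both functions return a 3 × 2 × 2 result, but only the first keeps each head's weights together.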