3 changes: 0 additions & 3 deletions docs/guides/checkpointing_solutions/convert_checkpoint.md
@@ -44,7 +44,6 @@ python3 -m pip install safetensors --no-deps
export MODEL=<HF_MODEL> # e.g. 'llama3.1-8b-Instruct'
export BASE_OUTPUT_DIRECTORY=<CKPT_PATH> # e.g., gs://my-bucket/my-checkpoint-directory
export USE_PATHWAYS=0 # Set to 1 if you intend to use Pathways for training, 0 for McJAX
-export LAZY_LOAD_TENSORS=<LAZY_LOAD> # Set to True to save RAM
```

### Run Conversion
@@ -63,7 +62,6 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
skip_jax_distributed_system=true \
checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
---lazy_load_tensors=${LAZY_LOAD_TENSORS?} \
--save_dtype=bfloat16
```

@@ -77,7 +75,6 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local.
- `hardware=cpu`: The conversion script runs on a CPU machine.
- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, these should be False.
-- `--lazy_load_tensors` (Optional): Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
- `--hf_model_path` (Optional): Specifies a customized remote directory or local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
- `--save_dtype` (Optional): Specifies the data type of the saved model weights. Defaults to `bfloat16` to save memory.
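
The `$((1 - USE_PATHWAYS))` arithmetic in the command maps the Pathways switch onto the two storage flags. A minimal sketch of that mapping, assuming `USE_PATHWAYS` is always 0 or 1 (the `storage_flags` helper is hypothetical and not part of MaxText):

```python
def storage_flags(use_pathways: int) -> dict[str, int]:
    """Mirror the shell expression $((1 - USE_PATHWAYS)) for both storage flags."""
    if use_pathways not in (0, 1):
        raise ValueError("USE_PATHWAYS must be 0 (McJAX) or 1 (Pathways)")
    value = 1 - use_pathways  # 1 for McJAX, 0 for Pathways
    return {
        "checkpoint_storage_use_zarr3": value,
        "checkpoint_storage_use_ocdbt": value,
    }


if __name__ == "__main__":
    for mode in (0, 1):
        print(mode, storage_flags(mode))
```

With `USE_PATHWAYS=0` both flags evaluate to 1, matching the McJAX default described above; with `USE_PATHWAYS=1` both evaluate to 0.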

5 changes: 4 additions & 1 deletion docs/tutorials/posttraining/multimodal.md
@@ -38,14 +38,17 @@ Then use this command to convert an unscanned checkpoint from HuggingFace to Max
```shell
# Your Hugging Face access token. Required to download gated models like Llama.
# You can generate one at https://huggingface.co/settings/tokens.
+# We explicitly set lazy_load_tensors to False here as lazy loading of tensors
+# is not supported when use_multimodal is True.
export HF_TOKEN=<Hugging Face access token>
export MAXTEXT_CKPT_PATH=<Checkpoint GCS path> # gs://my-bucket/path
python -m maxtext.checkpoint_conversion.to_maxtext \
model_name=gemma3-4b \
hf_access_token=${HF_TOKEN?} \
base_output_directory=${MAXTEXT_CKPT_PATH?} \
use_multimodal=true \
-scan_layers=false
+scan_layers=false \
+--lazy_load_tensors=False
```
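
Because lazy loading becomes the default in this change, multimodal conversions have to opt out explicitly. A minimal sketch of a wrapper that assembles the arguments and adds the opt-out only when it is needed; `build_conversion_args` is a hypothetical helper written for illustration, and it uses only the options shown in the command above:

```python
def build_conversion_args(
    model_name: str,
    output_dir: str,
    use_multimodal: bool,
) -> list[str]:
    """Assemble to_maxtext arguments, disabling lazy loading for multimodal runs."""
    args = [
        f"model_name={model_name}",
        f"base_output_directory={output_dir}",
        f"use_multimodal={str(use_multimodal).lower()}",
        "scan_layers=false",
    ]
    if use_multimodal:
        # Lazy loading is not supported together with use_multimodal=true.
        args.append("--lazy_load_tensors=False")
    return args


if __name__ == "__main__":
    print(build_conversion_args("gemma3-4b", "gs://my-bucket/path", True))
```

Text-only conversions built this way simply omit the flag and pick up the default.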

For the Llama4 model family, we use a separate checkpoint conversion script (we plan to gradually migrate all checkpoint conversion scripts to the consolidated tool above):
4 changes: 2 additions & 2 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
@@ -816,7 +816,7 @@ def _merged_getter(key):

def main(
args: Sequence[str],
-lazy_load_tensors: bool = False,
+lazy_load_tensors: bool = True,
eager_load_method: str = "safetensors",
hf_model_path: str | None = None,
revision: str | None = None,
@@ -1077,7 +1077,7 @@ def _eager_getter(key):
"--lazy_load_tensors",
type=str2bool,
required=False,
-default=False,
+default=True,
help="Whether to use lazy loading of HF tensors",
)
# Eager load uses `transformers_class.from_pretrained` with auto dtype or `safetensors.safe_open` with pt.
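
To illustrate what the new default means for callers, here is a minimal, self-contained sketch of a boolean flag parsed with `argparse`. The `str2bool` below is a hypothetical stand-in written for this example; the converter's own helper may accept different spellings.

```python
import argparse


def str2bool(value: str) -> bool:
    """Hypothetical stand-in for the converter's str2bool helper."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lazy_load_tensors",
        type=str2bool,
        required=False,
        default=True,  # the new default in this change
        help="Whether to use lazy loading of HF tensors",
    )
    return parser


if __name__ == "__main__":
    parser = build_parser()
    # Flag omitted: the default applies.
    print(parser.parse_args([]).lazy_load_tensors)
    # Flag given explicitly: the string is converted by str2bool.
    print(parser.parse_args(["--lazy_load_tensors=False"]).lazy_load_tensors)
```

Commands that previously omitted the flag now run with lazy loading enabled, which is why the multimodal commands in this change pass `--lazy_load_tensors=False` explicitly.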
3 changes: 2 additions & 1 deletion tests/end_to_end/tpu/gemma4/Run_Gemma4.md
@@ -56,7 +56,8 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml
hf_access_token=${HF_TOKEN} \
base_output_directory=${MODEL_BUCKET}/26b/converted/${idx} \
use_multimodal=true \
-scan_layers=false
+scan_layers=false \
+--lazy_load_tensors=False
```

This will convert the checkpoints and save them to a Google Cloud Storage (GCS) bucket.
1 change: 0 additions & 1 deletion tests/integration/checkpoint_conversion_test.py
@@ -54,7 +54,6 @@ def test_qwen3_30b_a3b_roundtrip_conversion(self):
"checkpoint_storage_use_ocdbt=False",
"checkpoint_storage_use_zarr3=False",
"--save_dtype=bfloat16",
-"--lazy_load_tensors=True",
]
env = os.environ.copy()
env["JAX_PLATFORMS"] = "cpu"