Commit 348f69a

committed
fix
1 parent d39bb99 commit 348f69a

2 files changed

Lines changed: 11 additions & 7 deletions

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 6 additions & 4 deletions
@@ -16,7 +16,9 @@ The following models are supported:
 | **Qwen3 MoE** | 30B, 235B, 480B |||||
 | **Mixtral** | 8x7B, 8x22B |||||
 | **GPT-OSS** | 20B, 120B |||||
-| **DeepSeek3** | 671B | - | - || - |
+| **DeepSeek2** | 16B |||||
+| **DeepSeek3** | 671B |||||
+| **DeepSeek3.2** | 671B ||| - | - |
 | **Qwen3 Next** | 80B |||||
 
 ## Prerequisites
@@ -73,7 +75,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
 model_name=${MODEL_NAME?} \
 hf_access_token=${HF_TOKEN?} \
 base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \
-scan_layers=True \
+scan_layers=true \
 use_multimodal=false \
 hardware=cpu \
 skip_jax_distributed_system=true \
@@ -90,7 +92,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
 - `hardware=cpu`: run the conversion script on a CPU machine.
 - `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above.
 - `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
-- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
+- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
 
 Above command will download the Hugging Face model to local machine if `hf_model_path` is unspecified, or reuse the checkpoint in `hf_model_path`. It will convert the checkpoint to the MaxText format and save it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.
 
@@ -217,7 +219,7 @@ To extend conversion support to a new model architecture, you must define its sp
 - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
 
 2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
-3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
+3. **Register model key**: In [`utils/globals.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
 4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.
 
 Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983)
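
Reviewer note on the guide text above: the bullet for `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt` says both values come from the shell arithmetic `$((1 - USE_PATHWAYS))`. A minimal Python sketch of that mapping follows. The function name `storage_flags` is hypothetical and not part of MaxText; it only restates the rule in the guide (McJAX, `USE_PATHWAYS=0`, gives True; Pathways, `USE_PATHWAYS=1`, gives False).

```python
def storage_flags(use_pathways: int) -> dict[str, bool]:
    """Restate the guide's $((1 - USE_PATHWAYS)) rule (hypothetical helper)."""
    if use_pathways not in (0, 1):
        raise ValueError("USE_PATHWAYS must be 0 or 1")
    # 1 - 0 = 1 (True) for McJAX; 1 - 1 = 0 (False) for Pathways.
    enabled = bool(1 - use_pathways)
    return {
        "checkpoint_storage_use_zarr3": enabled,
        "checkpoint_storage_use_ocdbt": enabled,
    }


if __name__ == "__main__":
    for value in (0, 1):
        print(f"USE_PATHWAYS={value}: {storage_flags(value)}")
```

Both flags always move together under this rule, which is why the guide derives them from a single variable.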

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 5 additions & 3 deletions
@@ -25,7 +25,7 @@
 Defaults to "./mt_output/".
 scan_layers: (bool) Whether the MaxText model was trained with scanned layers.
 This must match the training configuration of the checkpoint.
-lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM
+--lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM
 usage during conversion. Recommended if, 2 * model_size (GB) >= system RAM
 Defaults to False.
 --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
@@ -40,7 +40,7 @@
 Example Usage:
 To convert a gemma2-2b model and save it to a specific directory:
 
-/usr/bin/time -v python src/MaxText/checkpoint_conversion/to_maxtext.py \
+python -m maxtext.checkpoint_conversion.to_maxtext \
 maxtext/configs/base.yml model_name="gemma2-2b" \
 base_output_directory="/path/to/your/output/directory" \
 hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \
@@ -51,7 +51,7 @@
 
 To convert a 70B model with minimal RAM usage:
 
-/usr/bin/time -v python src/MaxText/checkpoint_conversion/to_maxtext.py \
+python -m maxtext.checkpoint_conversion.to_maxtext \
 maxtext/configs/base.yml model_name="llama3.1-70b" \
 base_output_directory="gs://my-bucket/maxtext-checkpoints" \
 hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \
@@ -601,6 +601,7 @@ def _slicing_loader(base_loader, slice_idx):
 
 def main(
 args: Sequence[str],
+test_args,
 hf_model_path: str | None = None,
 revision: str | None = None,
 lazy_load_tensors: bool = False,
@@ -905,6 +906,7 @@ def _eager_getter(key):
 os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
 main(
 args=model_args,
+test_args=local_args,
 hf_model_path=local_args.hf_model_path,
 revision=local_args.revision,
 lazy_load_tensors=local_args.lazy_load_tensors,
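
Reviewer note on the docstring in this file: it recommends lazy loading when `2 * model_size (GB) >= system RAM`. A small sketch of that rule of thumb follows. The helper name `should_lazy_load` and the sample sizes are hypothetical illustrations, not values or functions taken from MaxText.

```python
def should_lazy_load(model_size_gb: float, system_ram_gb: float) -> bool:
    """Apply the docstring's heuristic: lazy-load if 2 * model size >= RAM."""
    if model_size_gb <= 0 or system_ram_gb <= 0:
        raise ValueError("sizes must be positive")
    return 2 * model_size_gb >= system_ram_gb


if __name__ == "__main__":
    # Hypothetical (checkpoint size, host RAM) pairs in GB, for illustration only.
    for model_gb, ram_gb in [(140, 256), (5, 64)]:
        flag = "true" if should_lazy_load(model_gb, ram_gb) else "false"
        print(f"model={model_gb}GB ram={ram_gb}GB -> --lazy_load_tensors={flag}")
```

The factor of 2 comes straight from the docstring; the sketch does not model any other memory overhead of the conversion.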