Commit 2b06b9c

Merge pull request #3144 from AI-Hypercomputer:hengtaoguo-reckpt
PiperOrigin-RevId: 871568310
2 parents fa4e1e7 + 9930f1e commit 2b06b9c

92 files changed

Lines changed: 172 additions & 187 deletions
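The file diffs below are dominated by one package rename, from `MaxText.utils.ckpt_conversion` to `maxtext.checkpoint_conversion`. As a quick orientation, here is a minimal before/after import sketch in Python; the import shown appears verbatim in the `src/MaxText/integration/tunix/utils.py` diff further down.

```python
# Old package path (removed in this commit):
from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING

# New package path (added in this commit):
from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING
```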


.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 # Model bring-up
 src/MaxText/assets @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande
 src/MaxText/configs/models @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @suexu1025 @jesselu-google
-src/MaxText/utils/ckpt_conversion @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @hengtaoguo @gagika @shralex @richjames0 @NicoGrande
+src/maxtext/checkpoint_conversion @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @hengtaoguo @gagika @shralex @richjames0 @NicoGrande
 src/MaxText/layers @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @suexu1025 @jesselu-google

 # Features

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 9 additions & 9 deletions

@@ -1,6 +1,6 @@
 # Checkpoint conversion utilities

-This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/utils/ckpt_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.
+This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.

 ## Supported models

@@ -66,7 +66,7 @@ export LAZY_LOAD_TENSORS=<Flag to lazy load> # True to use lazy load, False to u
 Finally, run below command to complete the conversion

 ```bash
-python3 -m MaxText.utils.ckpt_conversion.to_maxtext maxtext/configs/base.yml \
+python3 -m maxtext.checkpoint_conversion.to_maxtext maxtext/configs/base.yml \
     model_name=${HF_MODEL} \
     hf_access_token=${HF_TOKEN} \
     base_output_directory=${MODEL_CHECKPOINT_DIRECTORY} \
@@ -90,7 +90,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext maxtext/configs/base.yml \
 - `checkpoint_storage_use_zarr3`: # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
 - `checkpoint_storage_use_ocdbt`: # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
 - `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. For large models, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
-- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/0d909c44391539db4e8cc2a33de9d77a891beb31/src/MaxText/utils/ckpt_conversion/utils/utils.py#L58-L85) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
+- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/0d909c44391539db4e8cc2a33de9d77a891beb31/src/MaxText/checkpoint_conversion/utils/utils.py#L58-L85) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

 Above command will download the Hugging Face model to local machine, convert it to the MaxText format and save it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.

@@ -104,7 +104,7 @@ Use the `to_huggingface.py` script to convert a MaxText checkpoint into the Hugg
 The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub.

 ```bash
-python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/maxtext/configs/base.yml \
+python3 -m maxtext.checkpoint_conversion.to_huggingface src/maxtext/configs/base.yml \
     model_name=<MODEL_NAME> \
     load_parameters_path=<path-to-maxtext-checkpoint> \
     base_output_directory=<path-to-save-converted-checkpoint> \
@@ -212,12 +212,12 @@ To extend conversion support to a new model architecture, you must define its sp

 1. **Add parameter mappings**:

-   - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer.
-   - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
+   - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer.
+   - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.

-2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
-3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py), add the new model key in `HF_IDS`.
-4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in ['src/maxtext/configs/models'](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.
+2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
+3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py), add the new model key in `HF_IDS`.
+4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in ['src/maxtext/configs/models'](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.

 Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983)
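For reference, the command fragments shown in this diff can be assembled into a single invocation of the renamed module. This is an illustrative sketch only: the model name, token, and output bucket are placeholder values, and the flag settings follow the McJAX recommendations quoted above.

```bash
# Illustrative only: values below are placeholders, not defaults.
export HF_MODEL=llama3.1-8b                                   # example model name
export HF_TOKEN=hf_...                                        # your Hugging Face access token
export MODEL_CHECKPOINT_DIRECTORY=gs://your-bucket/llama3.1-8b-ckpt   # example output path

python3 -m maxtext.checkpoint_conversion.to_maxtext maxtext/configs/base.yml \
    model_name=${HF_MODEL} \
    hf_access_token=${HF_TOKEN} \
    base_output_directory=${MODEL_CHECKPOINT_DIRECTORY} \
    checkpoint_storage_use_zarr3=True \
    checkpoint_storage_use_ocdbt=True \
    --lazy_load_tensors=true   # optional flag described in the guide above
```

Per the guide text above, the converted checkpoint would then land under `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.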

docs/tutorials/first_run.md

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ In the same TPU VM where you just installed all the dependencies of MaxText, You

 #### Decoding in MaxText via notebook

-You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint.
+You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint.

 ### Run MaxText on NVIDIA GPUs

docs/tutorials/posttraining/knowledge_distillation.md

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 export PRE_TRAINED_MODEL_CKPT_DIRECTORY=${BASE_DIRECTORY}/llama3.1-8b-ckpt

 # Convert to MaxText format
-python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/maxtext/configs/base.yml \
+python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
     model_name=llama3.1-8b \
     hf_access_token=${HF_TOKEN} \
     base_output_directory=${PRE_TRAINED_MODEL_CKPT_DIRECTORY} \

docs/tutorials/posttraining/multimodal.md

Lines changed: 3 additions & 3 deletions

@@ -25,7 +25,7 @@ Multimodal Large Language Models (LLMs) extend traditional text-only models by i

 ## Checkpoint Conversion

-Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md)).
+Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md)).

 Install pytorch:

@@ -38,7 +38,7 @@ Then use this command to convert an unscanned checkpoint from HuggingFace to Max
 ```shell
 export HF_ACCESS_TOKEN=hf_...
 export MAXTEXT_CKPT_GCS_PATH=gs://...
-python -m MaxText.utils.ckpt_conversion.to_maxtext maxtext/configs/base.yml \
+python -m maxtext.checkpoint_conversion.to_maxtext maxtext/configs/base.yml \
     model_name=gemma3-4b \
     hf_access_token=$HF_ACCESS_TOKEN \
     base_output_directory=$MAXTEXT_CKPT_GCS_PATH \
@@ -51,7 +51,7 @@ For the Llama4 model family, we are using a separate checkpoint conversion scrip
 ```shell
 export LOCAL_HF_MODEL_PATH=...  # Need to pre-download the safetensors from HuggingFace
 export MAXTEXT_CKPT_GCS_PATH=gs://...
-python -m MaxText.utils.ckpt_scripts.llama4_ckpt_unscanned \
+python -m maxtext.checkpoint_conversion.standalone_scripts.llama4_ckpt_unscanned \
     --model-size=llama4-17b-16e \
     --huggingface-checkpoint=True \
     --base-model-path=$LOCAL_HF_MODEL_PATH \

src/MaxText/integration/tunix/tunix_adapter.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@
 from flax import nnx
 from MaxText.layers.models import Transformer
 from MaxText.integration.tunix.utils import VllmWeightMapping
-from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS  # pylint: disable=ungrouped-imports
+from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS  # pylint: disable=ungrouped-imports


 class TunixMaxTextAdapter(nnx.Module):

src/MaxText/integration/tunix/utils.py

Lines changed: 2 additions & 2 deletions

@@ -17,8 +17,8 @@
 import re

 import MaxText.integration.tunix.weight_mapping as weight_mapping  # pylint: disable=consider-using-from-import
-from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
-from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
+from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING
+from maxtext.checkpoint_conversion.utils.param_mapping import VLLM_HOOK_FNS

 STANDALONE_VLLM_WEIGHT_MAPPING = weight_mapping.StandaloneVllmWeightMapping()

src/MaxText/rl/train_rl.py

Lines changed: 4 additions & 10 deletions

@@ -84,9 +84,9 @@ def get_maxtext_model(config, devices=None):
 """
 Load MaxText model with Tunix adapter.
 # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
-# To create a scanned checkpoint, you can use /maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py and if
+# To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
 # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False`
-# python src/MaxText/utils/ckpt_conversion/to_maxtext.py \
+# python src/MaxText/checkpoint_conversion/to_maxtext.py \
 # --model_name="gemma2-2b" \
 # --base_output_directory="/path/to/your/output/directory" \
 # --scan_layers=True \
@@ -321,10 +321,7 @@ def _filter_long_prompts(x):
 train_dataset = train_dataset[:dataset_size]
 train_dataset = train_dataset.repeat(trainer_config.num_epoch)

-train_dataset = (
-    train_dataset.to_iter_dataset()
-    .batch(trainer_config.batch_size)
-)
+train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)

 eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
 if not eval_dataset_name:
@@ -342,10 +339,7 @@ def _filter_long_prompts(x):
 test_dataset = test_dataset.filter(_filter_long_prompts)
 test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]

-test_dataset = (
-    test_dataset.to_iter_dataset()
-    .batch(trainer_config.batch_size)
-)
+test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)

 # Load reference model
 max_logging.log("Creating reference model and also meshes for reference and rollout")
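The `to_iter_dataset().batch()` chain reformatted above comes from the dataset pipeline in `train_rl.py`. As a standalone illustration of what that chain does, here is a hedged sketch; it assumes Google's grain library (whose top-level import path may vary by version), and the source data and batch size are made up for the example.

```python
# Sketch of the MapDataset -> IterDataset batching pattern used above.
# Assumes a recent grain release exposing MapDataset at the top level.
import grain

source = list(range(8))                         # stand-in for tokenized prompts
ds = grain.MapDataset.source(source)            # random-access (map-style) dataset
ds = ds.to_iter_dataset().batch(batch_size=4)   # stream it and emit fixed-size batches

for batch in ds:
    print(batch)                                # two batches of four elements each
```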

src/MaxText/utils/ckpt_conversion/utils/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.
File renamed without changes.
