Skip to content

Commit a76f8a0

Browse files
authored
Merge branch 'main' into atwigg/add_qwen3_base
2 parents c5b7e56 + 93d1b83 commit a76f8a0

44 files changed

Lines changed: 2282 additions & 292 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ jobs:
6464
6565
# 2. Install MaxText package and all the post training dependencies
6666
uv pip install ${maxtext_wheel}[tpu-post-train] --resolution=lowest
67-
#TODO: @mazumdera: replace this with the following after release
68-
# uv pip install maxtext[tpu-post-train] --resolution=lowest
6967
install_maxtext_tpu_post_train_extra_deps
7068
.venv/bin/python3 -m ipykernel install --user --name maxtext_venv
7169

PREFLIGHT.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ Before you run ML workload on Multihost with GCE or GKE, simply apply `bash pref
77

88
Here is an example for GCE:
99
```
10-
bash preflight.sh PLATFORM=GCE && python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=${YOUR_JOB_NAME?}
10+
bash preflight.sh PLATFORM=GCE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
1111
```
1212

1313
Here is an example for GKE:
1414
```
15-
bash preflight.sh PLATFORM=GKE && python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=${YOUR_JOB_NAME?}
15+
bash preflight.sh PLATFORM=GKE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
1616
```
1717

1818
# Optimization 2: Numa binding (You can only apply this to v4 and v5p)
@@ -22,14 +22,14 @@ For GCE,
2222
[preflight.sh](https://github.com/google/maxtext/blob/main/preflight.sh) will help you install `numactl` dependency, so you can use it directly, here is an example:
2323

2424
```
25-
bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=${YOUR_JOB_NAME?}
25+
bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
2626
```
2727

2828
For GKE,
2929
`numactl` should be built into your docker image from [maxtext_tpu_dependencies.Dockerfile](https://github.com/google/maxtext/blob/main/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile), so you can use it directly if you built the maxtext docker image. Here is an example
3030

3131
```
32-
bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=${YOUR_JOB_NAME?}
32+
bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
3333
```
3434

3535
1. `numactl`: This is the command-line tool used for controlling NUMA policy for processes or shared memory. It's particularly useful on multi-socket systems where memory locality can impact performance.

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,8 @@ export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repo
5656

5757
# -- MaxText configuration --
5858
export MODEL_CHECKPOINT_DIRECTORY=<output directory to store output of checking point> # e.g., gs://my-bucket/my-checkpoint-directory
59-
6059
# -- storage and format options
61-
export USE_ZARR3=<Flag to use zarr3> # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
62-
export USE_OCDBT=<Flag to use ocdbt> # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
60+
export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.
6361

6462
export LAZY_LOAD_TENSORS=<Flag to lazy load> # True to use lazy load, False to use eager load.
6563
```
@@ -70,29 +68,26 @@ Finally, run below command to complete the conversion
7068
# Optional: If run out of disk space when downloading HuggingFace safetensors,
7169
# customize your "HF_HOME" to redirect the cache to a larger or mounted disk (e.g., on a TPU VM).
7270
# export HF_HOME="/dev/shm/huggingface_tmp"
73-
python3 -m maxtext.checkpoint_conversion.to_maxtext maxtext/configs/base.yml \
71+
python3 -m maxtext.checkpoint_conversion.to_maxtext \
7472
model_name=${MODEL_NAME?} \
7573
hf_access_token=${HF_TOKEN?} \
7674
base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \
7775
scan_layers=True \
7876
use_multimodal=false \
7977
hardware=cpu \
8078
skip_jax_distributed_system=true \
81-
checkpoint_storage_use_zarr3=${USE_ZARR3?} \
82-
checkpoint_storage_use_ocdbt=${USE_OCDBT?} \
79+
checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
80+
checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
8381
--lazy_load_tensors=${LAZY_LOAD_TENSORS?}
8482
```
8583

86-
**Key arguments:**
87-
8884
- `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`.
8985
- `scan_layers`: Indicates if the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false).
9086
- `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
9187
- `hf_access_token`: Your Hugging Face token.
9288
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
9389
- `hardware=cpu`: run the conversion script on a CPU machine.
94-
- `checkpoint_storage_use_zarr3`: Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
95-
- `checkpoint_storage_use_ocdbt`: Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
90+
- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above.
9691
- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
9792
- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
9893

@@ -108,7 +103,7 @@ Use the `to_huggingface.py` script to convert a MaxText checkpoint into the Hugg
108103
The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub.
109104

110105
```bash
111-
python3 -m maxtext.checkpoint_conversion.to_huggingface src/maxtext/configs/base.yml \
106+
python3 -m maxtext.checkpoint_conversion.to_huggingface \
112107
model_name=<MODEL_NAME> \
113108
load_parameters_path=<path-to-maxtext-checkpoint> \
114109
base_output_directory=<path-to-save-converted-checkpoint> \
@@ -221,7 +216,7 @@ To extend conversion support to a new model architecture, you must define its sp
221216
- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
222217

223218
2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
224-
3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py), add the new model key in `HF_IDS`.
219+
3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
225220
4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.
226221

227222
Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983)

0 commit comments

Comments
 (0)