
Commit 0d4e75b

docs: simplify checkpoint storage flags for Pathways workloads
1 parent 8b86c71 commit 0d4e75b

4 files changed: 44 additions & 43 deletions


docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 4 additions & 9 deletions
````diff
@@ -56,10 +56,8 @@ export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repo
 
 # -- MaxText configuration --
 export MODEL_CHECKPOINT_DIRECTORY=<output directory to store the converted checkpoint> # e.g., gs://my-bucket/my-checkpoint-directory
-
 # -- storage and format options
-export USE_ZARR3=<Flag to use zarr3> # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
-export USE_OCDBT=<Flag to use ocdbt> # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
+export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.
 
 export LAZY_LOAD_TENSORS=<Flag to lazy load> # Set to True for lazy loading, False for eager loading.
 ```
````
````diff
@@ -78,21 +76,18 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
   use_multimodal=false \
   hardware=cpu \
   skip_jax_distributed_system=true \
-  checkpoint_storage_use_zarr3=${USE_ZARR3?} \
-  checkpoint_storage_use_ocdbt=${USE_OCDBT?} \
+  checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
+  checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
   --lazy_load_tensors=${LAZY_LOAD_TENSORS?}
 ```
 
-**Key arguments:**
-
 - `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`.
 - `scan_layers`: Indicates whether the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false).
 - `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
 - `hf_access_token`: Your Hugging Face token.
 - `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
 - `hardware=cpu`: Runs the conversion script on a CPU machine.
-- `checkpoint_storage_use_zarr3`: Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
-- `checkpoint_storage_use_ocdbt`: Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
+- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above.
 - `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
 - `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
````
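The `$((1 - USE_PATHWAYS))` idiom introduced by this commit is plain POSIX arithmetic expansion and can be sanity-checked in any shell. A minimal standalone sketch (illustrative variable names only, no MaxText involved):

```shell
# Map USE_PATHWAYS onto the storage flags: McJAX (USE_PATHWAYS=0) keeps
# zarr3/OCDBT enabled (1); Pathways (USE_PATHWAYS=1) disables them (0).
for USE_PATHWAYS in 0 1; do
  echo "USE_PATHWAYS=${USE_PATHWAYS} -> zarr3=$((1 - USE_PATHWAYS)) ocdbt=$((1 - USE_PATHWAYS))"
done
# prints:
# USE_PATHWAYS=0 -> zarr3=1 ocdbt=1
# USE_PATHWAYS=1 -> zarr3=0 ocdbt=0
```

The converted commands rely on `1`/`0` being accepted in place of `True`/`False` for these boolean flags, as shown in the diff above.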

docs/guides/checkpointing_solutions/gcs_checkpointing.md

Lines changed: 28 additions & 28 deletions
````diff
@@ -11,39 +11,39 @@ bucket.
 
 The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed:
 
-1. **Resume Current Run**: If a checkpoint already exists for the current
-   `run_name`, the system loads the latest fully-saved checkpoint. This is the
-   default behavior to ensure minimal state loss when resuming after an
-   interruption.
-2. **Load from Specific Path**: The system checks for a user-specified path.
-   * If `load_parameters_path` is set, we load a parameter-only checkpoint from that path.
-   * If `load_full_state_path` is set, we load a full-state checkpoint from that path.
-   * **Note**: These two options are mutually exclusive and will cause an error if both are set.
-3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.
+1. **Resume Current Run**: If a checkpoint already exists for the current
+   `run_name`, the system loads the latest fully-saved checkpoint. This is the
+   default behavior to ensure minimal state loss when resuming after an
+   interruption.
+2. **Load from Specific Path**: The system checks for a user-specified path.
+   - If `load_parameters_path` is set, we load a parameter-only checkpoint from that path.
+   - If `load_full_state_path` is set, we load a full-state checkpoint from that path.
+   - **Note**: These two options are mutually exclusive and will cause an error if both are set.
+3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.
 
 ### MaxText configuration
 
-| Flag | Description | Type | Default |
-| :--- | :--- | :--- | :--- |
-| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
-| `async_checkpointing` | When set to `True`, this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` |
-| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
-| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
-| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
-| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
-| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
-| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |
+| Flag | Description | Type | Default |
+| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :-------------- |
+| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
+| `async_checkpointing` | When set to `True`, this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` |
+| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
+| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
+| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
+| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
+| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
+| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |
 
 ## Storage and format configuration
 
 These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.
 
-| Flag | Description | Type | Default |
-| :--- | :--- | :--- | :--- |
-| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
-| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. | `boolean` | `True` |
-| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. | `boolean` | `True` |
-| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
-| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
-| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
-| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
+| Flag | Description | Type | Default |
+| :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ |
+| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
+| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` |
+| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` |
+| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
+| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
+| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
+| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
````
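The load-order rules in this file state that `load_parameters_path` and `load_full_state_path` are mutually exclusive. A hypothetical pre-flight check (not part of MaxText; variable names are illustrative) that fails fast before launching a job:

```shell
# Fail fast if both restore paths are set; MaxText would error out anyway,
# but checking up front avoids scheduling a doomed job.
LOAD_PARAMETERS_PATH="gs://my-bucket/my-previous-run/checkpoints/items/1000"
LOAD_FULL_STATE_PATH=""
if [ -n "${LOAD_PARAMETERS_PATH}" ] && [ -n "${LOAD_FULL_STATE_PATH}" ]; then
  echo "error: load_parameters_path and load_full_state_path are mutually exclusive" >&2
  exit 1
fi
echo "checkpoint restore configuration OK"
```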

docs/tutorials/posttraining/sft_on_multi_host.md

Lines changed: 8 additions & 3 deletions
````diff
@@ -117,8 +117,11 @@ export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-b
 
 **Note:** Make sure that `MODEL_CHECKPOINT_PATH` has the checkpoints created using the correct storage flags:
 
-- **For SFT with McJAX:** `checkpoint_storage_use_zarr3=True` and `checkpoint_storage_use_ocdbt=True`.
-- **For SFT with Pathways:** `checkpoint_storage_use_zarr3=False` and `checkpoint_storage_use_ocdbt=False`.
+```
+export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.
+checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
+checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS))
+```
 
 ### Option 2: Converting a Hugging Face checkpoint
 
@@ -151,6 +154,8 @@ Once the fine-tuning is completed, you can access your model checkpoints at `$OU
 ### 6.2. SFT with Pathways
 
 ```bash
+export USE_PATHWAYS=1
+
 xpk workload create-pathways \
   --cluster=${CLUSTER_NAME?} \
   --project=${PROJECT?} \
@@ -159,7 +164,7 @@ xpk workload create-pathways \
   --workload=${WORKLOAD_NAME?} \
   --tpu-type=${TPU_TYPE?} \
   --num-slices=${TPU_SLICE?} \
-  --command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
+  --command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
 ```
 
 Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
````
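Worth noting why the converted `--command` still works: arithmetic expansion runs in the launching shell inside the double-quoted string, so `xpk` receives literal `0`/`1` values. A minimal sketch of that expansion (the command string here is illustrative, not a full MaxText invocation):

```shell
export USE_PATHWAYS=1
# $((...)) expands inside double quotes, so the string already contains
# literal 0/1 by the time it would be handed to xpk.
cmd="train_sft checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS))"
echo "$cmd"
# prints: train_sft checkpoint_storage_use_zarr3=0 checkpoint_storage_use_ocdbt=0
```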

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 4 additions & 3 deletions
````diff
@@ -85,13 +85,14 @@ def get_maxtext_model(config, devices=None):
   Load MaxText model with Tunix adapter.
   # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
   # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
-  # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False`
+  # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags:
+  # export USE_PATHWAYS=1
   # python src/MaxText/checkpoint_conversion/to_maxtext.py \
   #   --model_name="gemma2-2b" \
   #   --base_output_directory="/path/to/your/output/directory" \
   #   --scan_layers=True \
-  #   --checkpoint_storage_use_ocdbt=False \
-  #   checkpoint_storage_use_zarr3=False
+  #   --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
+  #   --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
   # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
   # load_parameters_path=/path/to/your/output/directory/0/items
   """
````
