Commit 94c8c9d

Merge branch 'main' of github.com:AI-Hypercomputer/maxtext into shuningjin-qwix1

2 parents: 21c3248 + a3c19fd

11 files changed: 526 additions & 129 deletions

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 4 additions & 9 deletions
````diff
@@ -56,10 +56,8 @@ export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repo
 
 # -- MaxText configuration --
 export MODEL_CHECKPOINT_DIRECTORY=<output directory to store output of checkpointing> # e.g., gs://my-bucket/my-checkpoint-directory
-
 # -- storage and format options
-export USE_ZARR3=<Flag to use zarr3> # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
-export USE_OCDBT=<Flag to use ocdbt> # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
+export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.
 
 export LAZY_LOAD_TENSORS=<Flag to lazy load> # True to use lazy load, False to use eager load.
 ```
````
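The single `USE_PATHWAYS` switch replaces the two boolean environment variables; as a quick illustration (not part of the commit), shell arithmetic expansion derives both storage flags from it:

```shell
#!/usr/bin/env bash
set -euo pipefail

# McJAX default: USE_PATHWAYS=0, so both storage flags evaluate to 1 (True).
USE_PATHWAYS=0
echo "zarr3=$((1 - USE_PATHWAYS)) ocdbt=$((1 - USE_PATHWAYS))"   # zarr3=1 ocdbt=1

# Pathways: USE_PATHWAYS=1 flips both flags to 0 (False) from one place.
USE_PATHWAYS=1
echo "zarr3=$((1 - USE_PATHWAYS)) ocdbt=$((1 - USE_PATHWAYS))"   # zarr3=0 ocdbt=0
```

Note that any value other than 0 or 1 would silently yield a wrong flag, which is why the comment in the diff pins the switch to exactly those two values.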
````diff
@@ -78,21 +76,18 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
 use_multimodal=false \
 hardware=cpu \
 skip_jax_distributed_system=true \
-checkpoint_storage_use_zarr3=${USE_ZARR3?} \
-checkpoint_storage_use_ocdbt=${USE_OCDBT?} \
+checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
+checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
 --lazy_load_tensors=${LAZY_LOAD_TENSORS?}
 ```
 
-**Key arguments:**
-
 - `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`.
 - `scan_layers`: Indicates if the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false).
 - `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
 - `hf_access_token`: Your Hugging Face token.
 - `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
 - `hardware=cpu`: run the conversion script on a CPU machine.
-- `checkpoint_storage_use_zarr3`: Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
-- `checkpoint_storage_use_ocdbt`: Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
+- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above.
 - `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
 - `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
 
````

docs/guides/checkpointing_solutions/gcs_checkpointing.md

Lines changed: 28 additions & 28 deletions
```diff
@@ -11,39 +11,39 @@ bucket.
 
 The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed:
 
-1. **Resume Current Run**: If a checkpoint already exists for the current
-   `run_name`, the system loads the latest fully-saved checkpoint. This is the
-   default behavior to ensure minimal state loss when resuming after an
-   interruption.
-2. **Load from Specific Path**: The system checks for a user-specified path.
-   * If `load_parameters_path` is set, we load a parameter-only checkpoint from that path.
-   * If `load_full_state_path` is set, we load a full state checkpoint from that path.
-   * **Note**: These two options are mutually exclusive and will cause an error if both are set.
-3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.
+1. **Resume Current Run**: If a checkpoint already exists for the current
+   `run_name`, the system loads the latest fully-saved checkpoint. This is the
+   default behavior to ensure minimal state loss when resuming after an
+   interruption.
+2. **Load from Specific Path**: The system checks for a user-specified path.
+   - If `load_parameters_path` is set, we load a parameter-only checkpoint from that path.
+   - If `load_full_state_path` is set, we load a full state checkpoint from that path.
+   - **Note**: These two options are mutually exclusive and will cause an error if both are set.
+3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.
 
 ### MaxText configuration
 
-| Flag | Description | Type | Default |
-| :--- | :--- | :--- | :--- |
-| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
-| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` |
-| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
-| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
-| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
-| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
-| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
-| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |
+| Flag                                   | Description                                                                                                                                                                                                                                                                                                                               | Type      | Default         |
+| :------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------- | :-------------- |
+| `enable_checkpointing`                 | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False`         |
+| `async_checkpointing`                  | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True`          |
+| `checkpoint_period`                    | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000`         |
+| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False`         |
+| `load_parameters_path`                 | Specifies a path to a checkpoint directory to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string`  | `""` (disabled) |
+| `load_full_state_path`                 | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string`  | `""` (disabled) |
+| `lora_input_adapters_path`             | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string`  | `""` (disabled) |
+| `force_unroll`                         | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False`         |
 
 ## Storage and format configuration
 
 These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.
 
-| Flag | Description | Type | Default |
-| :--- | :--- | :--- | :--- |
-| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
-| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. | `boolean` | `True` |
-| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. | `boolean` | `True` |
-| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
-| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
-| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
-| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
+| Flag                                             | Description                                                                                                                                                                                                                                                 | Type                 | Default             |
+| :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ |
+| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer`            | `2147483648` (2 GB) |
+| `checkpoint_storage_use_ocdbt`                   | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean`            | `True`              |
+| `checkpoint_storage_use_zarr3`                   | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean`            | `True`              |
+| `checkpoint_storage_concurrent_gb`               | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer`            | `96`                |
+| `enable_orbax_v1`                                | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean`            | `False`             |
+| `source_checkpoint_layout`                       | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string`             | `"orbax"`           |
+| `checkpoint_conversion_fn`                       | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None`              |
```
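The three-step loading order described in this file can be sketched as a small decision function. This is an illustration, not MaxText code; `run_has_checkpoint` is a hypothetical stand-in for the check that the current `run_name` already has a saved step:

```python
def choose_restore_action(run_has_checkpoint: bool,
                          load_parameters_path: str = "",
                          load_full_state_path: str = "") -> str:
    """Mirror the startup order: resume run -> explicit path -> fresh init."""
    if run_has_checkpoint:
        # 1. Resume Current Run: latest fully-saved checkpoint for this run_name.
        return "resume_latest"
    if load_parameters_path and load_full_state_path:
        # The two path options are mutually exclusive; setting both is an error.
        raise ValueError("Set only one of load_parameters_path / load_full_state_path")
    if load_parameters_path:
        # 2a. Parameter-only checkpoint from the given path.
        return "load_parameters"
    if load_full_state_path:
        # 2b. Full state (params + optimizer state + step) from the given path.
        return "load_full_state"
    # 3. Initialize from Scratch.
    return "init_from_scratch"
```

For example, `choose_restore_action(False, load_parameters_path="gs://my-bucket/ckpt")` returns `"load_parameters"`, while a run that already has a checkpoint resumes it regardless of the path flags.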
