Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions docs/guides/checkpointing_solutions/convert_checkpoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repo

# -- MaxText configuration --
export MODEL_CHECKPOINT_DIRECTORY=<output directory to store output of checking point> # e.g., gs://my-bucket/my-checkpoint-directory

# -- storage and format options
export USE_ZARR3=<Flag to use zarr3> # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
export USE_OCDBT=<Flag to use ocdbt> # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.

export LAZY_LOAD_TENSORS=<Flag to lazy load> # True to use lazy load, False to use eager load.
```
Expand All @@ -78,21 +76,18 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
use_multimodal=false \
hardware=cpu \
skip_jax_distributed_system=true \
checkpoint_storage_use_zarr3=${USE_ZARR3?} \
checkpoint_storage_use_ocdbt=${USE_OCDBT?} \
checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
--lazy_load_tensors=${LAZY_LOAD_TENSORS?}
```

**Key arguments:**

- `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`.
- `scan_layers`: Indicates if the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false).
- `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
- `hf_access_token`: Your Hugging Face token.
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
- `hardware=cpu`: run the conversion script on a CPU machine.
- `checkpoint_storage_use_zarr3`: Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
- `checkpoint_storage_use_ocdbt`: Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above.
- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

Expand Down
56 changes: 28 additions & 28 deletions docs/guides/checkpointing_solutions/gcs_checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,39 @@ bucket.

The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed:

1. **Resume Current Run**: If a checkpoint already exists for the current
`run_name`, the system loads the latest fully-saved checkpoint. This is the
default behavior to ensure minimal state loss when resuming after an
interruption.
2. **Load from Specific Path**: The system checks for a user-specified path.
* If `load_parameters_path` is set, we load a parameter only checkpoint from that path..
* If `load_full_state_path` is set, we load a full state checkpoint from that path.
* **Note**: These two options are mutually exclusive and will cause an error if both are set.
3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.
1. **Resume Current Run**: If a checkpoint already exists for the current
`run_name`, the system loads the latest fully-saved checkpoint. This is the
default behavior to ensure minimal state loss when resuming after an
interruption.
2. **Load from Specific Path**: The system checks for a user-specified path.
- If `load_parameters_path` is set, we load a parameter only checkpoint from that path..
- If `load_full_state_path` is set, we load a full state checkpoint from that path.
- **Note**: These two options are mutually exclusive and will cause an error if both are set.
3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.

### MaxText configuration

| Flag | Description | Type | Default |
| :--- | :--- | :--- | :--- |
| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` |
| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |
| Flag | Description | Type | Default |
| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :-------------- |
| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` |
| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |

## Storage and format configuration

These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.

| Flag | Description | Type | Default |
| :--- | :--- | :--- | :--- |
| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. | `boolean` | `True` |
| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. | `boolean` | `True` |
| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
| Flag | Description | Type | Default |
| :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ |
| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` |
| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` |
| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
11 changes: 8 additions & 3 deletions docs/tutorials/posttraining/sft_on_multi_host.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,11 @@ export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-b

**Note:** Make sure that `MODEL_CHECKPOINT_PATH` has the checkpoints created using the correct storage flags:

- **For SFT with McJAX:** `checkpoint_storage_use_zarr3=True` and `checkpoint_storage_use_ocdbt=True`.
- **For SFT with Pathways:** `checkpoint_storage_use_zarr3=False` and `checkpoint_storage_use_ocdbt=False`.
```
export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.
Comment thread
SurbhiJainUSC marked this conversation as resolved.
checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS))
```

### Option 2: Converting a Hugging Face checkpoint

Expand Down Expand Up @@ -150,6 +153,8 @@ Once the fine-tuning is completed, you can access your model checkpoints at `$OU
### 6.2. SFT with Pathways

```bash
export USE_PATHWAYS=1

xpk workload create-pathways \
--cluster=${CLUSTER_NAME?} \
--project=${PROJECT?} \
Expand All @@ -158,7 +163,7 @@ xpk workload create-pathways \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
7 changes: 4 additions & 3 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,14 @@ def get_maxtext_model(config, devices=None):
Load MaxText model with Tunix adapter.
# Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
# To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
# using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False`
# using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags:
# export USE_PATHWAYS=1
# python src/MaxText/checkpoint_conversion/to_maxtext.py \
# --model_name="gemma2-2b" \
# --base_output_directory="/path/to/your/output/directory" \
# --scan_layers=True \
# --checkpoint_storage_use_ocdbt=False\
# checkpoint_storage_use_zarr3=False
# --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
# --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
# Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
# load_parameters_path=/path/to/your/output/directory/0/items
"""
Expand Down
Loading